sagemaker.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import json
  16. import os
  17. from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
  18. from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
  19. from ...utils.imports import is_boto3_available
  20. from .config_args import SageMakerConfig
  21. from .config_utils import (
  22. DYNAMO_BACKENDS,
  23. _ask_field,
  24. _ask_options,
  25. _convert_dynamo_backend,
  26. _convert_mixed_precision,
  27. _convert_sagemaker_distributed_mode,
  28. _convert_yes_no_to_bool,
  29. )
  30. if is_boto3_available():
  31. import boto3 # noqa: F401
  32. def _create_iam_role_for_sagemaker(role_name):
  33. iam_client = boto3.client("iam")
  34. sagemaker_trust_policy = {
  35. "Version": "2012-10-17",
  36. "Statement": [
  37. {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
  38. ],
  39. }
  40. try:
  41. # create the role, associated with the chosen trust policy
  42. iam_client.create_role(
  43. RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2)
  44. )
  45. policy_document = {
  46. "Version": "2012-10-17",
  47. "Statement": [
  48. {
  49. "Effect": "Allow",
  50. "Action": [
  51. "sagemaker:*",
  52. "ecr:GetDownloadUrlForLayer",
  53. "ecr:BatchGetImage",
  54. "ecr:BatchCheckLayerAvailability",
  55. "ecr:GetAuthorizationToken",
  56. "cloudwatch:PutMetricData",
  57. "cloudwatch:GetMetricData",
  58. "cloudwatch:GetMetricStatistics",
  59. "cloudwatch:ListMetrics",
  60. "logs:CreateLogGroup",
  61. "logs:CreateLogStream",
  62. "logs:DescribeLogStreams",
  63. "logs:PutLogEvents",
  64. "logs:GetLogEvents",
  65. "s3:CreateBucket",
  66. "s3:ListBucket",
  67. "s3:GetBucketLocation",
  68. "s3:GetObject",
  69. "s3:PutObject",
  70. ],
  71. "Resource": "*",
  72. }
  73. ],
  74. }
  75. # attach policy to role
  76. iam_client.put_role_policy(
  77. RoleName=role_name,
  78. PolicyName=f"{role_name}_policy_permission",
  79. PolicyDocument=json.dumps(policy_document, indent=2),
  80. )
  81. except iam_client.exceptions.EntityAlreadyExistsException:
  82. print(f"role {role_name} already exists. Using existing one")
  83. def _get_iam_role_arn(role_name):
  84. iam_client = boto3.client("iam")
  85. return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
  86. def get_sagemaker_input():
  87. credentials_configuration = _ask_options(
  88. "How do you want to authorize?",
  89. ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
  90. int,
  91. )
  92. aws_profile = None
  93. if credentials_configuration == 0:
  94. aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
  95. os.environ["AWS_PROFILE"] = aws_profile
  96. else:
  97. print(
  98. "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
  99. "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
  100. )
  101. aws_access_key_id = _ask_field("AWS Access Key ID: ")
  102. os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
  103. aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
  104. os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
  105. aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
  106. os.environ["AWS_DEFAULT_REGION"] = aws_region
  107. role_management = _ask_options(
  108. "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
  109. ["Provide IAM Role name", "Create new IAM role using credentials"],
  110. int,
  111. )
  112. if role_management == 0:
  113. iam_role_name = _ask_field("Enter your IAM role name: ")
  114. else:
  115. iam_role_name = "accelerate_sagemaker_execution_role"
  116. print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
  117. _create_iam_role_for_sagemaker(iam_role_name)
  118. is_custom_docker_image = _ask_field(
  119. "Do you want to use custom Docker image? [yes/NO]: ",
  120. _convert_yes_no_to_bool,
  121. default=False,
  122. error_message="Please enter yes or no.",
  123. )
  124. docker_image = None
  125. if is_custom_docker_image:
  126. docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())
  127. is_sagemaker_inputs_enabled = _ask_field(
  128. "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
  129. _convert_yes_no_to_bool,
  130. default=False,
  131. error_message="Please enter yes or no.",
  132. )
  133. sagemaker_inputs_file = None
  134. if is_sagemaker_inputs_enabled:
  135. sagemaker_inputs_file = _ask_field(
  136. "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
  137. lambda x: str(x).lower(),
  138. )
  139. is_sagemaker_metrics_enabled = _ask_field(
  140. "Do you want to enable SageMaker metrics? [yes/NO]: ",
  141. _convert_yes_no_to_bool,
  142. default=False,
  143. error_message="Please enter yes or no.",
  144. )
  145. sagemaker_metrics_file = None
  146. if is_sagemaker_metrics_enabled:
  147. sagemaker_metrics_file = _ask_field(
  148. "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
  149. lambda x: str(x).lower(),
  150. )
  151. distributed_type = _ask_options(
  152. "What is the distributed mode?",
  153. ["No distributed training", "Data parallelism"],
  154. _convert_sagemaker_distributed_mode,
  155. )
  156. dynamo_config = {}
  157. use_dynamo = _ask_field(
  158. "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
  159. _convert_yes_no_to_bool,
  160. default=False,
  161. error_message="Please enter yes or no.",
  162. )
  163. if use_dynamo:
  164. prefix = "dynamo_"
  165. dynamo_config[prefix + "backend"] = _ask_options(
  166. "Which dynamo backend would you like to use?",
  167. [x.lower() for x in DYNAMO_BACKENDS],
  168. _convert_dynamo_backend,
  169. default=2,
  170. )
  171. use_custom_options = _ask_field(
  172. "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
  173. _convert_yes_no_to_bool,
  174. default=False,
  175. error_message="Please enter yes or no.",
  176. )
  177. if use_custom_options:
  178. dynamo_config[prefix + "mode"] = _ask_options(
  179. "Which mode do you want to use?",
  180. TORCH_DYNAMO_MODES,
  181. lambda x: TORCH_DYNAMO_MODES[int(x)],
  182. default="default",
  183. )
  184. dynamo_config[prefix + "use_fullgraph"] = _ask_field(
  185. "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
  186. _convert_yes_no_to_bool,
  187. default=False,
  188. error_message="Please enter yes or no.",
  189. )
  190. dynamo_config[prefix + "use_dynamic"] = _ask_field(
  191. "Do you want to enable dynamic shape tracing? [yes/NO]: ",
  192. _convert_yes_no_to_bool,
  193. default=False,
  194. error_message="Please enter yes or no.",
  195. )
  196. dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
  197. "Do you want to enable regional compilation? [yes/NO]: ",
  198. _convert_yes_no_to_bool,
  199. default=False,
  200. error_message="Please enter yes or no.",
  201. )
  202. ec2_instance_query = "Which EC2 instance type you want to use for your training?"
  203. if distributed_type != SageMakerDistributedType.NO:
  204. ec2_instance_type = _ask_options(
  205. ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
  206. )
  207. else:
  208. ec2_instance_query += "? [ml.p3.2xlarge]:"
  209. ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")
  210. debug = False
  211. if distributed_type != SageMakerDistributedType.NO:
  212. debug = _ask_field(
  213. "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
  214. _convert_yes_no_to_bool,
  215. default=False,
  216. error_message="Please enter yes or no.",
  217. )
  218. num_machines = 1
  219. if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
  220. num_machines = _ask_field(
  221. "How many machines do you want use? [1]: ",
  222. int,
  223. default=1,
  224. )
  225. mixed_precision = _ask_options(
  226. "Do you wish to use FP16 or BF16 (mixed precision)?",
  227. ["no", "fp16", "bf16", "fp8"],
  228. _convert_mixed_precision,
  229. )
  230. if use_dynamo and mixed_precision == "no":
  231. print(
  232. "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
  233. )
  234. return SageMakerConfig(
  235. image_uri=docker_image,
  236. compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
  237. distributed_type=distributed_type,
  238. use_cpu=False,
  239. dynamo_config=dynamo_config,
  240. ec2_instance_type=ec2_instance_type,
  241. profile=aws_profile,
  242. region=aws_region,
  243. iam_role_name=iam_role_name,
  244. mixed_precision=mixed_precision,
  245. num_machines=num_machines,
  246. sagemaker_inputs_file=sagemaker_inputs_file,
  247. sagemaker_metrics_file=sagemaker_metrics_file,
  248. debug=debug,
  249. )