| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- #!/usr/bin/env python
- # Copyright 2021 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import json
- import os
- from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
- from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
- from ...utils.imports import is_boto3_available
- from .config_args import SageMakerConfig
- from .config_utils import (
- DYNAMO_BACKENDS,
- _ask_field,
- _ask_options,
- _convert_dynamo_backend,
- _convert_mixed_precision,
- _convert_sagemaker_distributed_mode,
- _convert_yes_no_to_bool,
- )
- if is_boto3_available():
- import boto3 # noqa: F401
def _create_iam_role_for_sagemaker(role_name):
    """Create an IAM role assumable by SageMaker and give it an inline permission policy.

    The role is created with a trust policy that lets the
    ``sagemaker.amazonaws.com`` service call ``sts:AssumeRole``. An inline
    policy named ``"{role_name}_policy_permission"`` is then put on the role,
    allowing the SageMaker, ECR, CloudWatch, CloudWatch Logs and S3 actions
    listed below on every resource.

    If IAM reports ``EntityAlreadyExistsException``, a message is printed and
    the existing role is left as it is (the inline policy is not written in
    that case, because the failing ``create_role`` call skips the rest of the
    ``try`` body).

    Args:
        role_name: Name of the IAM role to create.
    """
    client = boto3.client("iam")

    # Who may assume the role: the SageMaker service itself.
    trust_statement = {
        "Effect": "Allow",
        "Principal": {"Service": "sagemaker.amazonaws.com"},
        "Action": "sts:AssumeRole",
    }
    trust_policy = {"Version": "2012-10-17", "Statement": [trust_statement]}

    # What the role may do once assumed.
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    permission_policy = {
        "Version": "2012-10-17",
        "Statement": [{"Effect": "Allow", "Action": allowed_actions, "Resource": "*"}],
    }

    try:
        # Create the role first; the inline policy can only be put on an existing role.
        client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(trust_policy, indent=2),
        )
        client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(permission_policy, indent=2),
        )
    except client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
def _get_iam_role_arn(role_name):
    """Return the ARN of the IAM role named ``role_name``.

    The role is fetched with the IAM client's ``get_role`` call and the
    ``["Role"]["Arn"]`` entry of the response is returned.
    """
    role_response = boto3.client("iam").get_role(RoleName=role_name)
    return role_response["Role"]["Arn"]
def _ask_yes_no(question):
    """Ask a ``[yes/NO]`` question and return the converted answer.

    Thin wrapper around ``_ask_field`` that always uses
    ``_convert_yes_no_to_bool`` as the converter, ``False`` as the default and
    the shared "Please enter yes or no." error message.
    """
    return _ask_field(
        question,
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )


def get_sagemaker_input():
    """Interactively collect the settings for an Amazon SageMaker run.

    The user is asked, in order, about: AWS authorization, region, IAM role,
    custom Docker image, input channels file, metrics file, distributed mode,
    torch dynamo options, EC2 instance type, distributed debug checks, number
    of machines and mixed precision.

    Side effects:
        * Sets ``AWS_PROFILE`` *or* ``AWS_ACCESS_KEY_ID`` and
          ``AWS_SECRET_ACCESS_KEY`` in ``os.environ``, and always sets
          ``AWS_DEFAULT_REGION``.
        * When the user chooses to create a role, calls
          ``_create_iam_role_for_sagemaker`` for the role named
          ``accelerate_sagemaker_execution_role``.

    Returns:
        SageMakerConfig: the configuration built from the answers.
    """
    # --- AWS authorization -------------------------------------------------
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        # The two literals are concatenated, so the first one must end with a space.
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch your training script with "
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    # --- IAM role ----------------------------------------------------------
    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    # --- Docker image, input channels, metrics -----------------------------
    docker_image = None
    if _ask_yes_no("Do you want to use custom Docker image? [yes/NO]: "):
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    # File paths are kept exactly as typed: lowercasing them would point to a
    # different (usually missing) file on case-sensitive filesystems.
    sagemaker_inputs_file = None
    if _ask_yes_no("Do you want to provide SageMaker input channels with data locations? [yes/NO]: "):
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            str,
        )

    sagemaker_metrics_file = None
    if _ask_yes_no("Do you want to enable SageMaker metrics? [yes/NO]: "):
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            str,
        )

    # --- Distributed mode --------------------------------------------------
    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )

    # --- Torch dynamo ------------------------------------------------------
    dynamo_config = {}
    use_dynamo = _ask_yes_no("Do you wish to optimize your script with torch dynamo?[yes/NO]:")
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        if _ask_yes_no("Do you want to customize the defaults sent to torch.compile? [yes/NO]: "):
            # The menu default is an option index (like the backend prompt
            # above), and the converter calls int() on the selection, so a
            # string default cannot work here.
            # NOTE(review): index 0 is assumed to be the "default" mode -
            # confirm against TORCH_DYNAMO_MODES.
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_yes_no(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: "
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_yes_no(
                "Do you want to enable dynamic shape tracing? [yes/NO]: "
            )
            dynamo_config[prefix + "use_regional_compilation"] = _ask_yes_no(
                "Do you want to enable regional compilation? [yes/NO]: "
            )

    # --- EC2 instance type and distributed debug checks --------------------
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs pick from the fixed list of supported instances.
        ec2_instance_type = _ask_options(
            ec2_instance_query,
            SAGEMAKER_PARALLEL_EC2_INSTANCES,
            lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)],
        )
        debug = _ask_yes_no(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: "
        )
    else:
        # The base question already ends with "?", so only the default hint is appended.
        ec2_instance_query += " [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    # --- Number of machines ------------------------------------------------
    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    # --- Mixed precision ---------------------------------------------------
    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
|