| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- #!/usr/bin/env python
- # Copyright 2021 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- from ...utils.dataclasses import (
- ComputeEnvironment,
- DistributedType,
- DynamoBackend,
- FP8BackendType,
- PrecisionType,
- SageMakerDistributedType,
- )
- from ..menu import BulletMenu
# Names of the selectable TorchDynamo backends.
# Order matters: `_convert_dynamo_backend` indexes into this list with the
# position chosen by the user, then looks the name up in `DynamoBackend`.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    # Spelled to match its "AOT_TORCHXLA_TRACE_ONCE" sibling; the previous
    # "TORHCHXLA" spelling could not resolve to a backend name.
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt on stdin until an acceptable answer is given and return it.

    Args:
        input_text: Prompt passed to `input()`.
        convert_value: Optional callable applied to the raw answer; its result is returned.
        default: Returned as-is when it is not `None` and the answer is empty.
        error_message: Printed (when not `None`) each time `convert_value` raises.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)
        # An empty answer falls back to the default, bypassing the converter.
        if default is not None and answer == "":
            return default
        if convert_value is None:
            return answer
        try:
            return convert_value(answer)
        except Exception:
            # Any conversion failure means "ask again"; the message is optional.
            if error_message is not None:
                print(error_message)
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` and return the user's selection.

    Args:
        input_text: Prompt handed to `BulletMenu`.
        options: Choices handed to `BulletMenu`; `None` (the default) means an empty list.
        convert_value: Optional callable applied to the value returned by `menu.run`.
        default: Forwarded to `menu.run` as `default_choice`.

    Returns:
        The result of `menu.run`, passed through `convert_value` when one is given.
        NOTE(review): presumably the index of the selected option - confirm against `BulletMenu.run`.
    """
    # Build a fresh list per call so no mutable default is shared between invocations.
    choices = [] if options is None else options
    menu = BulletMenu(input_text, choices)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
def _convert_compute_environment(value):
    """
    Map a zero-based choice (anything `int()` accepts) to a `ComputeEnvironment`.

    Index 0 selects "LOCAL_MACHINE" and index 1 selects "AMAZON_SAGEMAKER".
    """
    environment_names = ["LOCAL_MACHINE", "AMAZON_SAGEMAKER"]
    chosen_name = environment_names[int(value)]
    return ComputeEnvironment(chosen_name)
def _convert_distributed_mode(value):
    """
    Map a zero-based choice (anything `int()` accepts) to a `DistributedType`.

    The position of each name in the list below is the index that selects it.
    """
    mode_names = [
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    ]
    chosen_name = mode_names[int(value)]
    return DistributedType(chosen_name)
def _convert_dynamo_backend(value):
    """
    Map a zero-based choice (anything `int()` accepts) to a dynamo backend value.

    The choice indexes into `DYNAMO_BACKENDS`; the resulting name is looked up in
    `DynamoBackend` and the member's `.value` is returned rather than the member itself.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    backend = DynamoBackend(backend_name)
    return backend.value
def _convert_mixed_precision(value):
    """
    Map a zero-based choice (anything `int()` accepts) to a `PrecisionType`.

    Indices 0 through 3 select "no", "fp16", "bf16" and "fp8" respectively.
    """
    precision_names = ["no", "fp16", "bf16", "fp8"]
    chosen_name = precision_names[int(value)]
    return PrecisionType(chosen_name)
def _convert_sagemaker_distributed_mode(value):
    """
    Map a zero-based choice (anything `int()` accepts) to a `SageMakerDistributedType`.

    Indices 0 through 2 select "NO", "DATA_PARALLEL" and "MODEL_PARALLEL" respectively.
    """
    mode_names = ["NO", "DATA_PARALLEL", "MODEL_PARALLEL"]
    chosen_name = mode_names[int(value)]
    return SageMakerDistributedType(chosen_name)
def _convert_fp8_backend(value):
    """
    Map a zero-based choice (anything `int()` accepts) to an `FP8BackendType`.

    Index 0 selects "TE" and index 1 selects "MSAMP".
    """
    backend_names = ["TE", "MSAMP"]
    chosen_name = backend_names[int(value)]
    return FP8BackendType(chosen_name)
def _convert_yes_no_to_bool(value):
    """
    Translate a case-insensitive "yes"/"no" answer into a boolean.

    Raises:
        KeyError: If the lower-cased answer is neither "yes" nor "no".
    """
    answer = value.lower()
    if answer == "yes":
        return True
    if answer == "no":
        return False
    raise KeyError(answer)
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands that strips the `<command> [<args>] ` placeholder from the usage line.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse build the usage text, then drop the placeholder wherever it appears.
        formatted = super()._format_usage(usage, actions, groups, prefix)
        return formatted.replace("<command> [<args>] ", "")
|