config_utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. from ...utils.dataclasses import (
  17. ComputeEnvironment,
  18. DistributedType,
  19. DynamoBackend,
  20. FP8BackendType,
  21. PrecisionType,
  22. SageMakerDistributedType,
  23. )
  24. from ..menu import BulletMenu
  25. DYNAMO_BACKENDS = [
  26. "EAGER",
  27. "AOT_EAGER",
  28. "INDUCTOR",
  29. "AOT_TS_NVFUSER",
  30. "NVPRIMS_NVFUSER",
  31. "CUDAGRAPHS",
  32. "OFI",
  33. "FX2TRT",
  34. "ONNXRT",
  35. "TENSORRT",
  36. "AOT_TORCHXLA_TRACE_ONCE",
  37. "TORHCHXLA_TRACE_ONCE",
  38. "IPEX",
  39. "TVM",
  40. ]
  41. def _ask_field(input_text, convert_value=None, default=None, error_message=None):
  42. ask_again = True
  43. while ask_again:
  44. result = input(input_text)
  45. try:
  46. if default is not None and len(result) == 0:
  47. return default
  48. return convert_value(result) if convert_value is not None else result
  49. except Exception:
  50. if error_message is not None:
  51. print(error_message)
  52. def _ask_options(input_text, options=[], convert_value=None, default=0):
  53. menu = BulletMenu(input_text, options)
  54. result = menu.run(default_choice=default)
  55. return convert_value(result) if convert_value is not None else result
  56. def _convert_compute_environment(value):
  57. value = int(value)
  58. return ComputeEnvironment(["LOCAL_MACHINE", "AMAZON_SAGEMAKER"][value])
  59. def _convert_distributed_mode(value):
  60. value = int(value)
  61. return DistributedType(
  62. [
  63. "NO",
  64. "MULTI_CPU",
  65. "MULTI_XPU",
  66. "MULTI_HPU",
  67. "MULTI_GPU",
  68. "MULTI_NPU",
  69. "MULTI_MLU",
  70. "MULTI_SDAA",
  71. "MULTI_MUSA",
  72. "XLA",
  73. ][value]
  74. )
  75. def _convert_dynamo_backend(value):
  76. value = int(value)
  77. return DynamoBackend(DYNAMO_BACKENDS[value]).value
  78. def _convert_mixed_precision(value):
  79. value = int(value)
  80. return PrecisionType(["no", "fp16", "bf16", "fp8"][value])
  81. def _convert_sagemaker_distributed_mode(value):
  82. value = int(value)
  83. return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value])
  84. def _convert_fp8_backend(value):
  85. value = int(value)
  86. return FP8BackendType(["TE", "MSAMP"][value])
  87. def _convert_yes_no_to_bool(value):
  88. return {"yes": True, "no": False}[value.lower()]
  89. class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
  90. """
  91. A custom formatter that will remove the usage line from the help message for subcommands.
  92. """
  93. def _format_usage(self, usage, actions, groups, prefix):
  94. usage = super()._format_usage(usage, actions, groups, prefix)
  95. usage = usage.replace("<command> [<args>] ", "")
  96. return usage