cli_argument_parser.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser
  2. from dataclasses import fields
  3. from typing import List
  4. class CliArgumentParser(ArgumentParser):
  5. """ Argument Parser to define and parse command-line args for training.
  6. Args:
  7. training_args: dict or list of dict which defines different
  8. parameters for training.
  9. """
  10. def __init__(self, training_args=None, **kwargs):
  11. if 'formatter_class' not in kwargs:
  12. kwargs['formatter_class'] = ArgumentDefaultsHelpFormatter
  13. super().__init__(**kwargs)
  14. self.training_args = training_args
  15. self.define_args()
  16. def get_manual_args(self, args):
  17. return [arg[2:] for arg in args if arg.startswith('--')]
  18. def _parse_known_args(self,
  19. args: List = None,
  20. namespace=None,
  21. *args_extra,
  22. **kwargs):
  23. self.model_id = namespace.model if namespace is not None else None
  24. if '--model' in args:
  25. self.model_id = args[args.index('--model') + 1]
  26. self.manual_args = self.get_manual_args(args)
  27. return super()._parse_known_args(args, namespace, *args_extra,
  28. **kwargs)
  29. def print_help(self, file=None):
  30. return super().print_help(file)
  31. def define_args(self):
  32. if self.training_args is not None:
  33. for f in fields(self.training_args):
  34. arg_name = f.name
  35. arg_attr = getattr(self.training_args, f.name)
  36. name = f'--{arg_name}'
  37. kwargs = dict(type=f.type, help=f.metadata['help'])
  38. kwargs['default'] = arg_attr
  39. if 'choices' in f.metadata:
  40. kwargs['choices'] = f.metadata['choices']
  41. kwargs['action'] = SingleAction
  42. self.add_argument(name, **kwargs)
  43. class DictAction(Action):
  44. """
  45. argparse action to split an argument into KEY=VALUE form
  46. on the first = and append to a dictionary. List options can
  47. be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
  48. brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
  49. list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
  50. """
  51. @staticmethod
  52. def parse_int_float_bool_str(val):
  53. try:
  54. return int(val)
  55. except ValueError:
  56. pass
  57. try:
  58. return float(val)
  59. except ValueError:
  60. pass
  61. if val.lower() in ['true', 'false']:
  62. return val.lower() == 'true'
  63. if val == 'None':
  64. return None
  65. return val
  66. @staticmethod
  67. def parse_iterable(val):
  68. """Parse iterable values in the string.
  69. All elements inside '()' or '[]' are treated as iterable values.
  70. Args:
  71. val (str): Value string.
  72. Returns:
  73. list | tuple: The expanded list or tuple from the string.
  74. Examples:
  75. >>> DictAction._parse_iterable('1,2,3')
  76. [1, 2, 3]
  77. >>> DictAction._parse_iterable('[a, b, c]')
  78. ['a', 'b', 'c']
  79. >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
  80. [(1, 2, 3), ['a', 'b'], 'c']
  81. """
  82. def find_next_comma(string):
  83. """Find the position of next comma in the string.
  84. If no ',' is found in the string, return the string length. All
  85. chars inside '()' and '[]' are treated as one element and thus ','
  86. inside these brackets are ignored.
  87. """
  88. assert (string.count('(') == string.count(')')) and (
  89. string.count('[')
  90. == string.count(']')), f'Imbalanced brackets exist in {string}'
  91. end = len(string)
  92. for idx, char in enumerate(string):
  93. pre = string[:idx]
  94. # The string before this ',' is balanced
  95. if ((char == ',') and (pre.count('(') == pre.count(')'))
  96. and (pre.count('[') == pre.count(']'))):
  97. end = idx
  98. break
  99. return end
  100. # Strip ' and " characters and replace whitespace.
  101. val = val.strip('\'\"').replace(' ', '')
  102. is_tuple = False
  103. if val.startswith('(') and val.endswith(')'):
  104. is_tuple = True
  105. val = val[1:-1]
  106. elif val.startswith('[') and val.endswith(']'):
  107. val = val[1:-1]
  108. elif ',' not in val:
  109. # val is a single value
  110. return DictAction.parse_int_float_bool_str(val)
  111. values = []
  112. while len(val) > 0:
  113. comma_idx = find_next_comma(val)
  114. element = DictAction.parse_iterable(val[:comma_idx])
  115. values.append(element)
  116. val = val[comma_idx + 1:]
  117. if is_tuple:
  118. values = tuple(values)
  119. return values
  120. def __call__(self, parser, namespace, values, option_string):
  121. options = {}
  122. for kv in values:
  123. key, val = kv.split('=', maxsplit=1)
  124. options[key] = self.parse_iterable(val)
  125. setattr(namespace, self.dest, options)
  126. class SingleAction(DictAction):
  127. """ Argparse action to convert value to tuple or list or nested structure of
  128. list and tuple, i.e 'V1,V2,V3', or with explicit brackets, i.e. '[V1,V2,V3]'.
  129. It also support nested brackets to build list/tuple values. e.g. '[(V1,V2),(V3,V4)]'
  130. """
  131. def __call__(self, parser, namespace, value, option_string):
  132. if isinstance(value, str):
  133. setattr(namespace, self.dest, self.parse_iterable(value))
  134. else:
  135. setattr(namespace, self.dest, value)