| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser
- from dataclasses import fields
- from typing import List
class CliArgumentParser(ArgumentParser):
    """Argument parser that defines and parses command-line args for training.

    Every dataclass field of ``training_args`` is registered as a
    ``--<field name>`` option whose default is the field's current value.

    Args:
        training_args: A dataclass instance (it is inspected with
            ``dataclasses.fields`` and ``getattr``) holding the training
            parameters, or ``None`` to register no options.
        **kwargs: Forwarded to ``argparse.ArgumentParser``. When
            ``formatter_class`` is not given it defaults to
            ``ArgumentDefaultsHelpFormatter``.
    """

    def __init__(self, training_args=None, **kwargs):
        kwargs.setdefault('formatter_class', ArgumentDefaultsHelpFormatter)
        super().__init__(**kwargs)
        self.training_args = training_args
        self.define_args()

    def get_manual_args(self, args):
        """Return the names of the long options present in ``args``.

        Args:
            args: Iterable of raw command-line tokens.

        Returns:
            list: For every token starting with ``--``, the option name
            without the leading dashes. For the ``--name=value`` spelling
            only ``name`` is kept, so both spellings report the same name.
        """
        return [
            arg[2:].split('=', 1)[0] for arg in args if arg.startswith('--')
        ]

    def _parse_known_args(self,
                          args: List = None,
                          namespace=None,
                          *args_extra,
                          **kwargs):
        """Record ``model_id`` and ``manual_args``, then parse as usual.

        ``self.model_id`` is taken from ``namespace.model`` when that
        attribute exists and is overridden by the first ``--model VALUE`` or
        ``--model=VALUE`` found in ``args``. ``self.manual_args`` is set to
        the names of all long options present in ``args``.

        Extra positional/keyword arguments are passed through untouched so
        the override keeps working if the parent signature carries more
        parameters.
        """
        tokens = list(args) if args is not None else []

        # The namespace may be missing or may not define `model` at all
        # (e.g. no such field was registered); fall back to None then.
        self.model_id = getattr(namespace, 'model', None)

        prefix = '--model='
        for idx, token in enumerate(tokens):
            if token == '--model':
                # A trailing `--model` has no value; leave reporting that
                # to argparse instead of failing here with IndexError.
                if idx + 1 < len(tokens):
                    self.model_id = tokens[idx + 1]
                break
            if token.startswith(prefix):
                self.model_id = token[len(prefix):]
                break

        self.manual_args = self.get_manual_args(tokens)
        return super()._parse_known_args(args, namespace, *args_extra,
                                         **kwargs)

    def print_help(self, file=None):
        """Print the help message; delegates to ``ArgumentParser``."""
        return super().print_help(file)

    def define_args(self):
        """Register one ``--<name>`` option per field of ``training_args``.

        The option uses the field's type as ``type``, the field's current
        value as ``default``, the ``help`` and (if present) ``choices``
        entries of the field metadata, and ``SingleAction`` as its action.
        Does nothing when ``training_args`` is ``None``.
        """
        if self.training_args is None:
            return
        for f in fields(self.training_args):
            kwargs = dict(
                type=f.type,
                # Fields declared without help metadata must not break
                # parser construction.
                help=f.metadata.get('help'),
                default=getattr(self.training_args, f.name),
                action=SingleAction,
            )
            if 'choices' in f.metadata:
                kwargs['choices'] = f.metadata['choices']
            self.add_argument(f'--{f.name}', **kwargs)
class DictAction(Action):
    """Argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each argument is split on the first ``=``. List values can be passed as
    comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit brackets,
    i.e. 'KEY=[V1,V2,V3]'. Nested brackets build nested list/tuple values,
    e.g. 'KEY=[(V1,V2),(V3,V4)]'.
    """

    @staticmethod
    def parse_int_float_bool_str(val):
        """Convert a string to the first matching scalar type.

        Tries ``int`` and then ``float``; otherwise maps 'true'/'false'
        (case-insensitive) to ``bool`` and the exact string 'None' to
        ``None``. Anything else is returned unchanged.
        """
        for caster in (int, float):
            try:
                return caster(val)
            except ValueError:
                pass
        lowered = val.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        if val == 'None':
            return None
        return val

    @staticmethod
    def parse_iterable(val):
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.
        Surrounding quote characters are stripped and all spaces removed
        before parsing.

        Args:
            val (str): Value string.

        Returns:
            list | tuple | Any: The expanded list or tuple from the string,
            or a single scalar when the string cannot be split.

        Raises:
            AssertionError: If the numbers of opening and closing brackets
                differ in a string that has to be split.

        Examples:
            >>> DictAction.parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction.parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction.parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
            >>> DictAction.parse_iterable('(1,2),(3,4)')
            [(1, 2), (3, 4)]
        """

        def find_next_comma(string):
            """Find the position of the next top-level comma in the string.

            If no such ',' is found, return the string length. All chars
            inside '()' and '[]' are treated as one element and thus ','
            inside these brackets are ignored.
            """
            # Raised explicitly (not via `assert`) so the check also runs
            # under `python -O`; the exception type is kept for callers.
            if (string.count('(') != string.count(')')
                    or string.count('[') != string.count(']')):
                raise AssertionError(
                    f'Imbalanced brackets exist in {string}')
            round_open = square_open = 0
            for idx, char in enumerate(string):
                # A comma splits only when everything before it is balanced.
                if char == ',' and round_open == 0 and square_open == 0:
                    return idx
                if char == '(':
                    round_open += 1
                elif char == ')':
                    round_open -= 1
                elif char == '[':
                    square_open += 1
                elif char == ']':
                    square_open -= 1
            return len(string)

        def is_wrapped(string, opener, closer):
            """Tell whether the leading bracket is closed by the last char.

            '(1,2),(3,4)' starts with '(' and ends with ')', but those two
            do not belong together, so it must not be unwrapped.
            """
            if (len(string) < 2 or string[0] != opener
                    or string[-1] != closer):
                return False
            depth = 0
            for idx, char in enumerate(string):
                if char in '([':
                    depth += 1
                elif char in ')]':
                    depth -= 1
                if depth == 0:
                    # The leading bracket closes here; it wraps the whole
                    # string only if this is the final character.
                    return idx == len(string) - 1
            return False

        # Strip ' and " characters and replace whitespace.
        val = val.strip('\'\"').replace(' ', '')
        is_tuple = False
        if is_wrapped(val, '(', ')'):
            is_tuple = True
            val = val[1:-1]
        elif is_wrapped(val, '[', ']'):
            val = val[1:-1]
        elif ',' not in val:
            # val is a single value
            return DictAction.parse_int_float_bool_str(val)
        elif find_next_comma(val) == len(val):
            # Every comma sits inside brackets that do not wrap the whole
            # string (e.g. 'f(1,2)'); it cannot be split any further, and
            # recursing on it would never terminate.
            return DictAction.parse_int_float_bool_str(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            values.append(DictAction.parse_iterable(val[:comma_idx]))
            val = val[comma_idx + 1:]
        if is_tuple:
            values = tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``{KEY: parsed VALUE}`` for all ``values`` on the namespace.

        Each item of ``values`` is split on its first '=' and the value part
        is parsed with ``parse_iterable``. The resulting dict replaces
        whatever was stored under ``self.dest``.
        """
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            options[key] = self.parse_iterable(val)
        setattr(namespace, self.dest, options)
class SingleAction(DictAction):
    """Argparse action that turns one value into a list/tuple structure.

    A string such as 'V1,V2,V3' or '[V1,V2,V3]' is expanded through
    ``parse_iterable``; nested brackets, e.g. '[(V1,V2),(V3,V4)]', build
    nested list/tuple values. Values that are not strings are stored as-is.
    """

    def __call__(self, parser, namespace, value, option_string):
        # Only raw strings need parsing; anything else is stored unchanged.
        parsed = self.parse_iterable(value) if isinstance(value, str) else value
        setattr(namespace, self.dest, parsed)
|