# args.py
# Copyright (c) Alibaba, Inc. and its affiliates.
  2. import argparse
  3. import json
  4. def str2bool(v):
  5. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  6. return True
  7. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  8. return False
  9. else:
  10. raise argparse.ArgumentTypeError('Unsupported value encountered.')
  11. class HParams(dict):
  12. """ Hyper-parameters class
  13. Store hyper-parameters in training / infer / ... scripts.
  14. """
  15. def __getattr__(self, name):
  16. if name in self.keys():
  17. return self[name]
  18. for v in self.values():
  19. if isinstance(v, HParams):
  20. if name in v:
  21. return v[name]
  22. raise AttributeError(f"'HParams' object has no attribute '{name}'")
  23. def __setattr__(self, name, value):
  24. self[name] = value
  25. def save(self, filename):
  26. with open(filename, 'w', encoding='utf-8') as fp:
  27. json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
  28. def load(self, filename):
  29. with open(filename, 'r', encoding='utf-8') as fp:
  30. params_dict = json.load(fp)
  31. for k, v in params_dict.items():
  32. if isinstance(v, dict):
  33. self[k].update(HParams(v))
  34. else:
  35. self[k] = v
  36. def parse_args(parser):
  37. """ Parse hyper-parameters from cmdline. """
  38. parsed = parser.parse_args()
  39. args = HParams()
  40. optional_args = parser._action_groups[1]
  41. for action in optional_args._group_actions[1:]:
  42. arg_name = action.dest
  43. args[arg_name] = getattr(parsed, arg_name)
  44. for group in parser._action_groups[2:]:
  45. group_args = HParams()
  46. for action in group._group_actions:
  47. arg_name = action.dest
  48. group_args[arg_name] = getattr(parsed, arg_name)
  49. if len(group_args) > 0:
  50. args[group.title] = group_args
  51. return args