# config.py (4.0 KB)

import copy
import os
import platform
import shutil
import sys
import time
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import numpy as np
import paddle
import paddle.distributed as dist
import yaml
from tqdm import tqdm

from utils import get_logger, print_dict
  13. class ArgsParser(ArgumentParser):
  14. def __init__(self):
  15. super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
  16. self.add_argument("-c", "--config", help="configuration file to use")
  17. self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
  18. self.add_argument(
  19. "-p",
  20. "--profiler_options",
  21. type=str,
  22. default=None,
  23. help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
  24. )
  25. def parse_args(self, argv=None):
  26. args = super(ArgsParser, self).parse_args(argv)
  27. assert args.config is not None, "Please specify --config=configure_file_path."
  28. args.opt = self._parse_opt(args.opt)
  29. return args
  30. def _parse_opt(self, opts):
  31. config = {}
  32. if not opts:
  33. return config
  34. for s in opts:
  35. s = s.strip()
  36. k, v = s.split("=")
  37. config[k] = yaml.load(v, Loader=yaml.Loader)
  38. return config
  39. class AttrDict(dict):
  40. """Single level attribute dict, NOT recursive"""
  41. def __init__(self, **kwargs):
  42. super(AttrDict, self).__init__()
  43. super(AttrDict, self).update(kwargs)
  44. def __getattr__(self, key):
  45. if key in self:
  46. return self[key]
  47. raise AttributeError("object has no attribute '{}'".format(key))
  48. global_config = AttrDict()
  49. default_config = {
  50. "Global": {
  51. "debug": False,
  52. }
  53. }
  54. def load_config(file_path):
  55. """
  56. Load config from yml/yaml file.
  57. Args:
  58. file_path (str): Path of the config file to be loaded.
  59. Returns: global config
  60. """
  61. merge_config(default_config)
  62. _, ext = os.path.splitext(file_path)
  63. assert ext in [".yml", ".yaml"], "only support yaml files for now"
  64. merge_config(yaml.load(open(file_path, "rb"), Loader=yaml.Loader))
  65. return global_config
  66. def merge_config(config):
  67. """
  68. Merge config into global config.
  69. Args:
  70. config (dict): Config to be merged.
  71. Returns: global config
  72. """
  73. for key, value in config.items():
  74. if "." not in key:
  75. if isinstance(value, dict) and key in global_config:
  76. global_config[key].update(value)
  77. else:
  78. global_config[key] = value
  79. else:
  80. sub_keys = key.split(".")
  81. assert (
  82. sub_keys[0] in global_config
  83. ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
  84. global_config.keys(), sub_keys[0]
  85. )
  86. cur = global_config[sub_keys[0]]
  87. for idx, sub_key in enumerate(sub_keys[1:]):
  88. if idx == len(sub_keys) - 2:
  89. cur[sub_key] = value
  90. else:
  91. cur = cur[sub_key]
  92. def preprocess(is_train=False):
  93. FLAGS = ArgsParser().parse_args()
  94. profiler_options = FLAGS.profiler_options
  95. config = load_config(FLAGS.config)
  96. merge_config(FLAGS.opt)
  97. profile_dic = {"profiler_options": FLAGS.profiler_options}
  98. merge_config(profile_dic)
  99. if is_train:
  100. # save_config
  101. save_model_dir = config["save_model_dir"]
  102. os.makedirs(save_model_dir, exist_ok=True)
  103. with open(os.path.join(save_model_dir, "config.yml"), "w") as f:
  104. yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
  105. log_file = "{}/train.log".format(save_model_dir)
  106. else:
  107. log_file = None
  108. logger = get_logger(log_file=log_file)
  109. # check if set use_gpu=True in paddlepaddle cpu version
  110. use_gpu = config["use_gpu"]
  111. print_dict(config, logger)
  112. return config, logger
  113. if __name__ == "__main__":
  114. config, logger = preprocess(is_train=False)
  115. # print(config)