config.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Major implementation is borrowed and modified from
  3. # https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
  4. import copy
  5. import os
  6. import os.path as osp
  7. import platform
  8. import shutil
  9. import sys
  10. import tempfile
  11. import types
  12. from pathlib import Path
  13. from types import FunctionType
  14. from typing import Dict, Union
  15. import addict
  16. import json
  17. from modelscope.utils.constant import ConfigFields, ModelFile
  18. from modelscope.utils.logger import get_logger
  logger = get_logger()

  # Special keys recognised inside config dicts. Only DELETE_KEY is used in
  # the visible part of this module (by Config._merge_a_into_b).
  BASE_KEY = '_base_'
  DELETE_KEY = '_delete_'
  DEPRECATION_KEY = '_deprecation_'

  # Names of Config's own properties (filename / text / pretty_text); they may
  # not be used as top-level config keys (checked in Config.__init__).
  RESERVED_KEYS = ['filename', 'text', 'pretty_text']
  class ConfigDict(addict.Dict):
      """A dict whose values can also be read as attributes.

      Looking up an absent key raises ``KeyError`` (via ``__missing__``), and
      the resulting ``KeyError`` from attribute access is reported as an
      ``AttributeError`` so ``getattr``/``hasattr`` behave as usual.

      Examples:
          >>> cdict = ConfigDict({'a': 1232})
          >>> print(cdict.a)
          1232
      """

      def __missing__(self, name):
          # Never auto-create entries for unknown keys.
          raise KeyError(name)

      def __getattr__(self, name):
          try:
              return super(ConfigDict, self).__getattr__(name)
          except KeyError:
              pass
          # Raised outside the handler so no KeyError context is chained.
          raise AttributeError(
              f"'{type(self).__name__}' object has no attribute '{name}'")
  class Config:
      """A facility for config and config files.

      It supports common file formats as configs: python/json/yaml. The
      interface is the same as a dict object and also allows accessing config
      values as attributes.

      Example:
          >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd')))
          >>> cfg.a
          1
          >>> cfg.b
          {'c': [1, 2, 3], 'd': 'dd'}
          >>> cfg.b.d
          'dd'
          >>> cfg = Config.from_file('configs/examples/configuration.json')
          >>> cfg.filename
          'configs/examples/configuration.json'
          >>> cfg.b
          {'c': [1, 2, 3], 'd': 'dd'}
          >>> cfg = Config.from_file('configs/examples/configuration.py')
          >>> cfg.filename
          'configs/examples/configuration.py'
          >>> cfg = Config.from_file('configs/examples/configuration.yaml')
          >>> cfg.filename
          'configs/examples/configuration.yaml'
      """

      @staticmethod
      def _file2dict(filename):
          """Load a config file into a plain dict plus its raw text.

          The file is copied into a temporary directory first; a python config
          is then imported from that copy as a module.

          Args:
              filename (str): Path of a ``.py``, ``.json``, ``.yaml`` or
                  ``.yml`` file. ``~`` is expanded.

          Returns:
              tuple: ``(cfg_dict, cfg_text)`` where ``cfg_text`` is the
              absolute path, a newline, then the file content.

          Raises:
              ValueError: If the file does not exist.
              IOError: If the extension is not supported.
          """
          filename = osp.abspath(osp.expanduser(filename))
          if not osp.exists(filename):
              raise ValueError(f'File does not exists {filename}')
          file_ext = osp.splitext(filename)[1]
          if file_ext not in ['.py', '.json', '.yaml', '.yml']:
              raise IOError('Only py/yml/yaml/json type are supported now!')
          with tempfile.TemporaryDirectory() as tmp_cfg_dir:
              tmp_cfg_file = tempfile.NamedTemporaryFile(
                  dir=tmp_cfg_dir, suffix=file_ext)
              # On Windows the handle is released before copying over the
              # file (an open NamedTemporaryFile generally cannot be reopened
              # there).
              if platform.system() == 'Windows':
                  tmp_cfg_file.close()
              tmp_cfg_name = osp.basename(tmp_cfg_file.name)
              shutil.copyfile(filename, tmp_cfg_file.name)
              try:
                  if filename.endswith('.py'):
                      # import as needed.
                      from modelscope.utils.import_utils import import_modules_from_file
                      module_name, mod = import_modules_from_file(
                          osp.join(tmp_cfg_dir, tmp_cfg_name))
                      cfg_dict = {}
                      for name, value in mod.__dict__.items():
                          if not name.startswith('__') and \
                                  not isinstance(value, types.ModuleType) and \
                                  not isinstance(value, types.FunctionType):
                              cfg_dict[name] = value
                      # drop the temporary module from the import cache
                      sys.modules.pop(module_name, None)
                  elif filename.endswith(('.yml', '.yaml', '.json')):
                      from modelscope.fileio import load
                      cfg_dict = load(tmp_cfg_file.name)
              finally:
                  # close the temp file even if loading failed
                  tmp_cfg_file.close()
          cfg_text = filename + '\n'
          # Setting encoding explicitly to resolve coding issue on windows
          with open(filename, 'r', encoding='utf-8') as f:
              cfg_text += f.read()
          return cfg_dict, cfg_text

      @staticmethod
      def from_file(filename):
          """Build a :obj:`Config` from a py/json/yaml/yml file.

          Args:
              filename (str or Path): Path of the config file.

          Returns:
              :obj:`Config`: Config obj.
          """
          if isinstance(filename, Path):
              filename = str(filename)
          cfg_dict, cfg_text = Config._file2dict(filename)
          return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

      @staticmethod
      def from_string(cfg_str, file_format):
          """Generate config from config str.

          Args:
              cfg_str (str): Config str.
              file_format (str): Config file format corresponding to the
                  config str. Only py/yml/yaml/json type are supported now!

          Returns:
              :obj:`Config`: Config obj.
          """
          if file_format not in ['.py', '.json', '.yaml', '.yml']:
              raise IOError('Only py/yml/yaml/json type are supported now!')
          if file_format != '.py' and 'dict(' in cfg_str:
              # check if users specify a wrong suffix for python
              logger.warning(
                  'Please check "file_format", the file format may be .py')
          # The file is written and closed first, then loaded by path
          # (on windows, previous implementation cause error; see PR 1077).
          with tempfile.NamedTemporaryFile(
                  'w', encoding='utf-8', suffix=file_format,
                  delete=False) as temp_file:
              temp_file.write(cfg_str)
          try:
              cfg = Config.from_file(temp_file.name)
          finally:
              # delete=False above, so removal is our job even on failure
              os.remove(temp_file.name)
          return cfg

      def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
          """Create a config.

          Args:
              cfg_dict (dict, optional): Config values, wrapped in a
                  :class:`ConfigDict`. Defaults to an empty dict.
              cfg_text (str, optional): Raw text exposed as ``text``.
              filename (str or Path, optional): Source file; when ``cfg_text``
                  is empty, ``text`` is read from this file.

          Raises:
              TypeError: If ``cfg_dict`` is not a dict.
              KeyError: If a top-level key is in ``RESERVED_KEYS``.
          """
          if cfg_dict is None:
              cfg_dict = dict()
          elif not isinstance(cfg_dict, dict):
              raise TypeError('cfg_dict must be a dict, but '
                              f'got {type(cfg_dict)}')
          for key in cfg_dict:
              if key in RESERVED_KEYS:
                  raise KeyError(f'{key} is reserved for config file')
          if isinstance(filename, Path):
              filename = str(filename)
          # object-level setattr: our own __setattr__ writes into _cfg_dict
          super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
          super(Config, self).__setattr__('_filename', filename)
          if cfg_text:
              text = cfg_text
          elif filename:
              with open(filename, 'r', encoding='utf-8') as f:
                  text = f.read()
          else:
              text = ''
          super(Config, self).__setattr__('_text', text)

      @property
      def filename(self):
          """str or None: The file this config was created from."""
          return self._filename

      @property
      def text(self):
          """str: The raw config text."""
          return self._text

      @property
      def pretty_text(self):
          """str: The config rendered as python source, formatted by yapf."""
          indent = 4

          def _indent(s_, num_spaces):
              # indent every line except the first
              s = s_.split('\n')
              if len(s) == 1:
                  return s_
              first = s.pop(0)
              s = [(num_spaces * ' ') + line for line in s]
              s = '\n'.join(s)
              s = first + '\n' + s
              return s

          def _format_basic_types(k, v, use_mapping=False):
              if isinstance(v, str):
                  v_str = f"'{v}'"
              else:
                  v_str = str(v)

              if use_mapping:
                  k_str = f"'{k}'" if isinstance(k, str) else str(k)
                  attr_str = f'{k_str}: {v_str}'
              else:
                  attr_str = f'{str(k)}={v_str}'
              attr_str = _indent(attr_str, indent)

              return attr_str

          def _format_list(k, v, use_mapping=False):
              # check if all items in the list are dict
              if all(isinstance(_, dict) for _ in v):
                  v_str = '[\n'
                  v_str += '\n'.join(
                      f'dict({_indent(_format_dict(v_), indent)}),'
                      for v_ in v).rstrip(',')
                  if use_mapping:
                      k_str = f"'{k}'" if isinstance(k, str) else str(k)
                      attr_str = f'{k_str}: {v_str}'
                  else:
                      attr_str = f'{str(k)}={v_str}'
                  attr_str = _indent(attr_str, indent) + ']'
              else:
                  attr_str = _format_basic_types(k, v, use_mapping)
              return attr_str

          def _contain_invalid_identifier(dict_str):
              # True if any key cannot be written as a python keyword arg
              contain_invalid_identifier = False
              for key_name in dict_str:
                  contain_invalid_identifier |= \
                      (not str(key_name).isidentifier())
              return contain_invalid_identifier

          def _format_dict(input_dict, outest_level=False):
              r = ''
              s = []

              use_mapping = _contain_invalid_identifier(input_dict)
              if use_mapping:
                  r += '{'
              for idx, (k, v) in enumerate(input_dict.items()):
                  is_last = idx >= len(input_dict) - 1
                  end = '' if outest_level or is_last else ','
                  if isinstance(v, dict):
                      v_str = '\n' + _format_dict(v)
                      if use_mapping:
                          k_str = f"'{k}'" if isinstance(k, str) else str(k)
                          attr_str = f'{k_str}: dict({v_str}'
                      else:
                          attr_str = f'{str(k)}=dict({v_str}'
                      attr_str = _indent(attr_str, indent) + ')' + end
                  elif isinstance(v, list):
                      attr_str = _format_list(k, v, use_mapping) + end
                  else:
                      attr_str = _format_basic_types(k, v, use_mapping) + end

                  s.append(attr_str)
              r += '\n'.join(s)
              if use_mapping:
                  r += '}'
              return r

          cfg_dict = self._cfg_dict.to_dict()
          text = _format_dict(cfg_dict, outest_level=True)
          # copied from setup.cfg
          yapf_style = dict(
              based_on_style='pep8',
              blank_line_before_nested_class_or_def=True,
              split_before_expression_after_opening_paren=True)
          from yapf.yapflib.yapf_api import FormatCode
          text, _ = FormatCode(text, style_config=yapf_style, verify=True)

          return text

      def __repr__(self):
          return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

      def __len__(self):
          return len(self._cfg_dict)

      def __getattr__(self, name):
          # Only reached when normal lookup fails; delegate to _cfg_dict.
          # Read it through __dict__ so an instance without _cfg_dict (e.g.
          # built via Config.__new__) raises AttributeError instead of
          # recursing into __getattr__('_cfg_dict') forever.
          try:
              cfg_dict = self.__dict__['_cfg_dict']
          except KeyError:
              raise AttributeError(
                  f"'{type(self).__name__}' object has no attribute '{name}'"
              ) from None
          return getattr(cfg_dict, name)

      def __getitem__(self, name):
          return self._cfg_dict.__getitem__(name)

      def __setattr__(self, name, value):
          # attribute assignment writes into the config, not the instance
          if isinstance(value, dict):
              value = ConfigDict(value)
          self._cfg_dict.__setattr__(name, value)

      def __setitem__(self, name, value):
          if isinstance(value, dict):
              value = ConfigDict(value)
          self._cfg_dict.__setitem__(name, value)

      def __iter__(self):
          return iter(self._cfg_dict)

      def __getstate__(self):
          return (self._cfg_dict, self._filename, self._text)

      def __copy__(self):
          cls = self.__class__
          other = cls.__new__(cls)
          other.__dict__.update(self.__dict__)

          return other

      def __deepcopy__(self, memo):
          cls = self.__class__
          other = cls.__new__(cls)
          memo[id(self)] = other

          for key, value in self.__dict__.items():
              super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

          return other

      def __setstate__(self, state):
          _cfg_dict, _filename, _text = state
          super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
          super(Config, self).__setattr__('_filename', _filename)
          super(Config, self).__setattr__('_text', _text)

      def safe_get(self, key_chain: str, default=None, type_field='type'):
          """Get a value with a key-chain in str format, if key does not exist, the default value will be returned.

          This method is safe to call, and will not edit any value.

          Args:
              key_chain: The input key chain, for example: 'train.hooks[0].type'
              default: The default value returned when any key does not exist, default None.
              type_field: Get an object from a list or tuple for example by 'train.hooks.CheckPointHook', in which
                  'hooks' is a list, and 'CheckPointHook' is a value of the content of key `type_field`.
                  If there are multiple matched objects, the first element will be returned.

          Returns:
              The value, or the default value.
          """
          try:
              keys = key_chain.split('.')
              _cfg_dict = self._cfg_dict
              for key in keys:
                  val = None
                  if '[' in key:
                      key, val = key.split('[')
                      val, _ = val.split(']')

                  if isinstance(_cfg_dict, (list, tuple)):
                      assert type_field is not None, 'Getting object without an index from a list or tuple ' \
                                                     'needs an valid `type_field` param.'
                      _sub_cfg_dict = list(
                          filter(lambda sub: sub[type_field] == key, _cfg_dict))
                      _cfg_dict = _sub_cfg_dict[0]
                  else:
                      _cfg_dict = _cfg_dict[key]
                  if val is not None:
                      _cfg_dict = _cfg_dict[int(val)]
              return _cfg_dict
          except Exception as e:
              # Any lookup failure (missing key, bad index, no match) is
              # reported at debug level and mapped to `default`.
              logger.debug(
                  f'Key not valid in Config: {key_chain}, return the default value: {default}'
              )
              logger.debug(e)
              return default

      def dump(self, file: str = None):
          """Dumps config into a file or returns a string representation of the
          config.

          If a file argument is given, saves the config to that file using the
          format defined by the file argument extension.

          Otherwise, returns a string representing the config. The formatting of
          this returned string is defined by the extension of `self.filename`. If
          `self.filename` is not defined or is a ``.py`` file, returns
          :attr:`pretty_text`.

          Examples:
              >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
              ...     item3=True, item4='test')
              >>> cfg = Config(cfg_dict=cfg_dict)
              >>> dump_file = "a.py"
              >>> cfg.dump(dump_file)

          Args:
              file (str, optional): Path of the output file where the config
                  will be dumped. Defaults to None.

          Returns:
              The config text when ``file`` is None; None for a ``.py`` file;
              otherwise whatever ``modelscope.fileio.dump`` returns.
          """
          from modelscope.fileio import dump
          cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
          if file is None:
              if self.filename is None or self.filename.endswith('.py'):
                  return self.pretty_text
              else:
                  file_format = self.filename.split('.')[-1]
                  return dump(cfg_dict, file_format=file_format)
          elif file.endswith('.py'):
              with open(file, 'w', encoding='utf-8') as f:
                  f.write(self.pretty_text)
          else:
              file_format = file.split('.')[-1]
              return dump(cfg_dict, file=file, file_format=file_format)

      def merge_from_dict(self, options, allow_list_keys=True, force=True):
          """Merge dict into cfg_dict.

          Merge the dict parsed by MultipleKVAction into this cfg.

          Examples:
              >>> options = {'model.backbone.depth': 50,
              ...            'model.backbone.with_cp':True}
              >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
              >>> cfg.merge_from_dict(options)
              >>> assert cfg.to_dict() == dict(
              ...     model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True)))

              >>> # Merge list element for replace target index
              >>> cfg = Config(dict(pipeline=[
              ...     dict(type='Resize'), dict(type='RandomDistortion')]))
              >>> options = dict(pipeline={'0': dict(type='MyResize')})
              >>> cfg.merge_from_dict(options, allow_list_keys=True)
              >>> assert cfg.to_dict() == dict(pipeline=[
              ...     dict(type='MyResize'), dict(type='RandomDistortion')])

              >>> # Merge list element for replace args and add to list, only support list of type dict with key ``type``,
              >>> # if you add new list element, the list does not guarantee the order,
              >>> # it is only suitable for the case where the order of the list is not concerned.
              >>> cfg = Config(dict(pipeline=[
              ...     dict(type='Resize', size=224), dict(type='RandomDistortion')]))
              >>> options = dict(pipeline=[dict(type='Resize', size=256), dict(type='RandomFlip')])
              >>> cfg.merge_from_dict(options, allow_list_keys=True)
              >>> assert cfg.to_dict() == dict(pipeline=[
              ...     dict(type='Resize', size=256), dict(type='RandomDistortion'), dict(type='RandomFlip')])

              >>> # force usage
              >>> options = {'model.backbone.depth': 18,
              ...            'model.backbone.with_cp':True}
              >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50))))
              >>> cfg.merge_from_dict(options, force=False)
              >>> assert cfg.to_dict() == dict(
              ...     model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True)))

          Args:
              options (dict): dict of configs to merge from. Keys may be
                  dotted paths such as ``'model.backbone.depth'``.
              allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                  are allowed in ``options`` and will replace the element of the
                  corresponding index in the config if the config is a list.
                  Or you can directly replace args for list or add new list element,
                  only support list of type dict with key ``type``,
                  but if you add new list element, the list does not guarantee the order,
                  It is only suitable for the case where the order of the list is not concerned.
                  Default: True.
              force (bool): If True, existing key-value will be replaced by new given.
                  If False, existing key-value will not be updated.
          """
          # expand dotted keys into a nested dict
          option_cfg_dict = {}
          for full_key, v in options.items():
              d = option_cfg_dict
              key_list = full_key.split('.')
              for subkey in key_list[:-1]:
                  d.setdefault(subkey, ConfigDict())
                  d = d[subkey]
              subkey = key_list[-1]
              d[subkey] = v

          cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
          super(Config, self).__setattr__(
              '_cfg_dict',
              Config._merge_a_into_b(
                  option_cfg_dict,
                  cfg_dict,
                  allow_list_keys=allow_list_keys,
                  force=force))

      @staticmethod
      def _merge_typed_lists(b_list, a_list, allow_list_keys, force):
          """Merge two lists whose elements are all dicts with a ``type`` key.

          Each element of ``b_list`` is merged with the first element of
          ``a_list`` that has the same ``type``. The result lists the merged
          elements first, then the unmatched elements of ``b_list``, then the
          unmatched elements of ``a_list``.
          """
          merged = []
          used_b, used_a = set(), set()
          for i, b_item in enumerate(b_list):
              for j, a_item in enumerate(a_list):
                  if a_item['type'] == b_item['type']:
                      merged.append(
                          Config._merge_a_into_b(
                              a_item, b_item, allow_list_keys, force=force))
                      used_a.add(j)
                      used_b.add(i)
                      break
          rest = [item for i, item in enumerate(b_list) if i not in used_b]
          rest += [item for j, item in enumerate(a_list) if j not in used_a]
          merged += [
              Config._merge_a_into_b(item, {}, allow_list_keys, force=force)
              for item in rest
          ]
          return merged

      @staticmethod
      def _merge_a_into_b(a, b, allow_list_keys=False, force=True):
          """merge dict ``a`` into dict ``b`` (non-inplace).

          Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
          in-place modifications, and ``a`` is never modified.

          Args:
              a (dict): The source dict to be merged into ``b``.
              b (dict): The origin dict to be fetch keys from ``a``.
              allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in source ``a`` and will replace the element of the
                corresponding index in b if b is a list. Default: False.
              force (bool): If True, existing key-value will be replaced by new given.
                  If False, existing key-value will not be updated.

          Returns:
              dict: The modified dict of ``b`` using ``a``.

          Raises:
              KeyError: If an index key is out of range for list ``b``.
              ValueError: If ``a`` has a list where ``b`` has a non-list.
              TypeError: If ``a`` has a dict where ``b`` has an incompatible
                  type and ``_delete_`` is not set.

          Examples:
              # Normally merge a into b.
              >>> Config._merge_a_into_b(
              ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
              {'obj': {'a': 2}}

              # Delete b first and merge a into b.
              >>> Config._merge_a_into_b(
              ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
              {'obj': {'a': 2}}

              # b is a list
              >>> Config._merge_a_into_b(
              ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
              [{'a': 2}, {'b': 2}]

              # Both values are lists of dicts with a ``type`` key: elements of
              # the same type are merged, the others are appended, so the
              # order of the list is not guaranteed.
              >>> Config._merge_a_into_b(
              ...     {'k': [dict(type='A', x=2), dict(type='C')]},
              ...     {'k': [dict(type='A', x=1), dict(type='B')]}, True)
              {'k': [{'type': 'A', 'x': 2}, {'type': 'B'}, {'type': 'C'}]}

              # Both values are lists but not all elements are typed dicts:
              # the list from ``a`` replaces the one in ``b``.
              >>> Config._merge_a_into_b({'k': [1, 2]}, {'k': [3]}, True)
              {'k': [1, 2]}

              # force is False
              >>> Config._merge_a_into_b(
              ...     dict(obj=dict(a=2, b=2)), dict(obj=dict(a=1)), True,
              ...     force=False)
              {'obj': {'a': 1, 'b': 2}}
          """
          b = b.copy()
          for k, v in a.items():
              if (allow_list_keys and isinstance(k, str) and k.isdigit()
                      and isinstance(b, list)):
                  k = int(k)
                  if len(b) <= k:
                      raise KeyError(f'Index {k} exceeds the length of list {b}')
                  b[k] = Config._merge_a_into_b(
                      v, b[k], allow_list_keys, force=force)
              elif allow_list_keys and isinstance(v, list) and k in b:
                  if not isinstance(b[k], list):
                      raise ValueError(
                          f'type mismatch {type(v)} and {type(b[k])} between a and b for key {k}'
                      )
                  if all(
                          isinstance(item, dict) and 'type' in item
                          for item in b[k] + v):
                      b[k] = Config._merge_typed_lists(
                          b[k], v, allow_list_keys, force)
                  elif force:
                      # not a list of typed dicts: replace it wholesale
                      b[k] = v
              elif isinstance(v, dict) and k in b:
                  # Strip the delete marker from a copy so the caller's ``a``
                  # is left untouched.
                  if DELETE_KEY in v:
                      v = v.copy()
                      delete = v.pop(DELETE_KEY)
                  else:
                      delete = False
                  if delete:
                      # replace b[k] instead of merging into it
                      if force:
                          b[k] = v
                  else:
                      allowed_types = (dict, list) if allow_list_keys else dict
                      if not isinstance(b[k], allowed_types):
                          raise TypeError(
                              f'{k}={v} in child config cannot inherit from base '
                              f'because {k} is a dict in the child config but is of '
                              f'type {type(b[k])} in base config. You may set '
                              f'`{DELETE_KEY}=True` to ignore the base config')
                      b[k] = Config._merge_a_into_b(
                          v, b[k], allow_list_keys, force=force)
              elif k not in b or force:
                  b[k] = v
          return b

      def to_dict(self) -> Dict:
          """ Convert Config object to python dict
          """
          return self._cfg_dict.to_dict()

      def to_args(self, parse_fn, use_hyphen=True):
          """ Convert config obj to args using parse_fn

          Args:
              parse_fn: a function object, which takes args as input,
                  such as ['--foo', 'FOO'] and return parsed args, an
                  example is given as follows
                  including literal blocks::
                      def parse_fn(args):
                          parser = argparse.ArgumentParser(prog='PROG')
                          parser.add_argument('-x')
                          parser.add_argument('--foo')
                          return parser.parse_args(args)
              use_hyphen (bool, optional): if set true, underscores in key
                  names are converted to hyphens (``foo_bar`` -> ``--foo-bar``).

          Return:
              args: whatever ``parse_fn`` returns for the built argument list.

          Raises:
              ValueError: If a value (or a list element) has an unsupported
                  type.
          """
          args = []
          for k, v in self._cfg_dict.items():
              arg_name = f'--{k}'
              if use_hyphen:
                  arg_name = arg_name.replace('_', '-')
              if isinstance(v, bool) and v:
                  args.append(arg_name)
              elif isinstance(v, (int, str, float)):
                  args.append(arg_name)
                  args.append(str(v))
              elif isinstance(v, list):
                  # validate the elements (the list itself is never a scalar)
                  for item in v:
                      if not isinstance(item, (int, str, float, bool)):
                          raise ValueError(
                              'Element type in list is expected to be either '
                              f'int,str,float, but got type {type(item)}')
                  args.append(arg_name)
                  args.append(str(v))
              else:
                  raise ValueError(
                      'type in config file which supported to be '
                      'converted to args should be either bool, '
                      f'int, str, float or list of them but got type {v}')

          return parse_fn(args)
  def check_config(cfg: Union[str, ConfigDict], is_training=False):
      """ Check whether configuration file is valid, If anything wrong, exception will be raised.

      Args:
          cfg (str or ConfigDict): Config file path or config object. A path
              is loaded with ``Config.from_file``.
          is_training: indicate if checking training related elements
              (model, train, preprocessor, evaluation).

      Raises:
          AssertionError: If a required attribute is missing. It is raised
              explicitly rather than via an ``assert`` statement so the check
              still runs under ``python -O``.
      """
      if isinstance(cfg, str):
          cfg = Config.from_file(cfg)

      def check_attr(attr_name, msg=''):
          if not hasattr(cfg, attr_name):
              raise AssertionError(
                  f'Attribute {attr_name} is missing from '
                  f'{ModelFile.CONFIGURATION}. {msg}')

      check_attr(ConfigFields.framework)
      check_attr(ConfigFields.task)
      check_attr(ConfigFields.pipeline)

      if is_training:
          check_attr(ConfigFields.model)
          check_attr(ConfigFields.train)
          check_attr(ConfigFields.preprocessor)
          check_attr(ConfigFields.evaluation)
  class JSONIteratorEncoder(json.JSONEncoder):
      """JSON encoder that serializes arbitrary iterables as lists.

      Plain python functions are encoded as ``null``; any other object that
      supports ``iter()`` is materialized with ``list()``; everything else is
      handed to the base implementation, which raises ``TypeError``.
      """

      def default(self, obj):
          if isinstance(obj, FunctionType):
              return None
          try:
              iterator = iter(obj)
          except TypeError:
              iterator = None
          if iterator is None:
              # not iterable: let the base class raise TypeError
              return json.JSONEncoder.default(self, obj)
          return list(iterator)