training_args.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import re
  3. from copy import deepcopy
  4. from dataclasses import dataclass, field, fields
  5. from typing import List, Union
  6. import addict
  7. import json
  8. from modelscope.trainers.cli_argument_parser import CliArgumentParser
  9. from modelscope.utils.config import Config
  10. from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE
  11. def set_flatten_value(values: Union[str, List[str]]):
  12. pairs = values.split(',') if isinstance(values, str) else values
  13. _params = {}
  14. for kv in pairs or []:
  15. if len(kv.strip()) == 0:
  16. continue
  17. key, value = kv.split('=')
  18. _params[key] = parse_value(value)
  19. return _params
  20. @dataclass
  21. class DatasetArgs:
  22. train_dataset_name: str = field(
  23. default=None,
  24. metadata={
  25. 'help':
  26. 'The dataset name used for training, can be an id in the datahub or a local dir',
  27. })
  28. val_dataset_name: str = field(
  29. default=None,
  30. metadata={
  31. 'help':
  32. 'The subset name used for evaluating, can be an id in the datahub or a local dir',
  33. })
  34. train_subset_name: str = field(
  35. default=None,
  36. metadata={
  37. 'help': 'The subset name used for training, can be None',
  38. })
  39. val_subset_name: str = field(
  40. default=None,
  41. metadata={
  42. 'help': 'The subset name used for evaluating, can be None',
  43. })
  44. train_split: str = field(
  45. default=None, metadata={
  46. 'help': 'The split of train dataset',
  47. })
  48. val_split: str = field(
  49. default=None, metadata={
  50. 'help': 'The split of val dataset',
  51. })
  52. train_dataset_namespace: str = field(
  53. default=DEFAULT_DATASET_NAMESPACE,
  54. metadata={
  55. 'help': 'The dataset namespace used for training',
  56. })
  57. val_dataset_namespace: str = field(
  58. default=DEFAULT_DATASET_NAMESPACE,
  59. metadata={
  60. 'help': 'The dataset namespace used for evaluating',
  61. })
  62. dataset_json_file: str = field(
  63. default=None,
  64. metadata={
  65. 'help':
  66. 'The json file to parse all datasets from, used in a complex dataset scenario,'
  67. 'the json format should be like:'
  68. '''
  69. [
  70. {
  71. "dataset": {
  72. # All args used in the MsDataset.load function
  73. "dataset_name": "xxx",
  74. ...
  75. },
  76. # All columns used, mapping the column names in each dataset in same names.
  77. "column_mapping": {
  78. "text1": "sequence1",
  79. "text2": "sequence2",
  80. "label": "label",
  81. },
  82. # float or str, float means to split the dataset into train/val,
  83. # or just str(train/val)
  84. "split": 0.8,
  85. }
  86. ]
  87. ''',
  88. })
  89. @dataclass
  90. class ModelArgs:
  91. task: str = field(
  92. default=None,
  93. metadata={
  94. 'help': 'The task code to be used',
  95. 'cfg_node': 'task'
  96. })
  97. model: str = field(
  98. default=None, metadata={
  99. 'help': 'A model id or model dir',
  100. })
  101. model_revision: str = field(
  102. default=None, metadata={
  103. 'help': 'the revision of model',
  104. })
  105. model_type: str = field(
  106. default=None,
  107. metadata={
  108. 'help':
  109. 'The mode type, if load_model_config is False, user need to fill this field',
  110. 'cfg_node': 'model.type'
  111. })
  112. @dataclass
  113. class TrainArgs:
  114. seed: int = field(
  115. default=42, metadata={
  116. 'help': 'The random seed',
  117. })
  118. per_device_train_batch_size: int = field(
  119. default=16,
  120. metadata={
  121. 'cfg_node': 'train.dataloader.batch_size_per_gpu',
  122. 'help':
  123. 'The `batch_size_per_gpu` argument for the train dataloader',
  124. })
  125. train_data_worker: int = field(
  126. default=0,
  127. metadata={
  128. 'cfg_node': 'train.dataloader.workers_per_gpu',
  129. 'help': 'The `workers_per_gpu` argument for the train dataloader',
  130. })
  131. train_shuffle: bool = field(
  132. default=False,
  133. metadata={
  134. 'cfg_node': 'train.dataloader.shuffle',
  135. 'help': 'The `shuffle` argument for the train dataloader',
  136. })
  137. train_drop_last: bool = field(
  138. default=False,
  139. metadata={
  140. 'cfg_node': 'train.dataloader.drop_last',
  141. 'help': 'The `drop_last` argument for the train dataloader',
  142. })
  143. per_device_eval_batch_size: int = field(
  144. default=16,
  145. metadata={
  146. 'cfg_node': 'evaluation.dataloader.batch_size_per_gpu',
  147. 'help':
  148. 'The `batch_size_per_gpu` argument for the eval dataloader',
  149. })
  150. eval_data_worker: int = field(
  151. default=0,
  152. metadata={
  153. 'cfg_node': 'evaluation.dataloader.workers_per_gpu',
  154. 'help': 'The `workers_per_gpu` argument for the eval dataloader',
  155. })
  156. eval_shuffle: bool = field(
  157. default=False,
  158. metadata={
  159. 'cfg_node': 'evaluation.dataloader.shuffle',
  160. 'help': 'The `shuffle` argument for the eval dataloader',
  161. })
  162. eval_drop_last: bool = field(
  163. default=False,
  164. metadata={
  165. 'cfg_node': 'evaluation.dataloader.drop_last',
  166. 'help': 'The `drop_last` argument for the eval dataloader',
  167. })
  168. max_epochs: int = field(
  169. default=5,
  170. metadata={
  171. 'cfg_node': 'train.max_epochs',
  172. 'help': 'The training epochs',
  173. })
  174. work_dir: str = field(
  175. default='./train_target',
  176. metadata={
  177. 'cfg_node': 'train.work_dir',
  178. 'help': 'The directory to save models and logs',
  179. })
  180. lr: float = field(
  181. default=5e-5,
  182. metadata={
  183. 'cfg_node': 'train.optimizer.lr',
  184. 'help': 'The learning rate of the optimizer',
  185. })
  186. lr_scheduler: str = field(
  187. default='LinearLR',
  188. metadata={
  189. 'cfg_node': 'train.lr_scheduler.type',
  190. 'help': 'The lr_scheduler type in torch',
  191. })
  192. optimizer: str = field(
  193. default='AdamW',
  194. metadata={
  195. 'cfg_node': 'train.optimizer.type',
  196. 'help': 'The optimizer type in PyTorch, like `AdamW`',
  197. })
  198. optimizer_params: str = field(
  199. default=None,
  200. metadata={
  201. 'cfg_node': 'train.optimizer',
  202. 'help': 'The optimizer params',
  203. 'cfg_setter': set_flatten_value,
  204. })
  205. lr_scheduler_params: str = field(
  206. default=None,
  207. metadata={
  208. 'cfg_node': 'train.lr_scheduler',
  209. 'help': 'The lr scheduler params',
  210. 'cfg_setter': set_flatten_value,
  211. })
  212. lr_strategy: str = field(
  213. default='by_epoch',
  214. metadata={
  215. 'cfg_node': 'train.lr_scheduler.options.lr_strategy',
  216. 'help': 'The lr decay strategy',
  217. 'choices': ['by_epoch', 'by_step', 'no'],
  218. })
  219. local_rank: int = field(
  220. default=0, metadata={
  221. 'help': 'The local rank',
  222. })
  223. logging_interval: int = field(
  224. default=5,
  225. metadata={
  226. 'help': 'The interval of iter of logging information',
  227. 'cfg_node': 'train.logging.interval',
  228. })
  229. eval_strategy: str = field(
  230. default='by_epoch',
  231. metadata={
  232. 'help': 'Eval strategy, can be `by_epoch` or `by_step` or `no`',
  233. 'cfg_node': 'evaluation.period.eval_strategy',
  234. 'choices': ['by_epoch', 'by_step', 'no'],
  235. })
  236. eval_interval: int = field(
  237. default=1,
  238. metadata={
  239. 'help': 'Eval interval',
  240. 'cfg_node': 'evaluation.period.interval',
  241. })
  242. eval_metrics: str = field(
  243. default=None,
  244. metadata={
  245. 'help': 'The metric name for evaluation',
  246. 'cfg_node': 'evaluation.metrics'
  247. })
  248. save_strategy: str = field(
  249. default='by_epoch',
  250. metadata={
  251. 'help':
  252. 'Checkpointing strategy, can be `by_epoch` or `by_step` or `no`',
  253. 'cfg_node': 'train.checkpoint.period.save_strategy',
  254. 'choices': ['by_epoch', 'by_step', 'no'],
  255. })
  256. save_interval: int = field(
  257. default=1,
  258. metadata={
  259. 'help':
  260. 'The interval of epoch or iter of saving checkpoint period',
  261. 'cfg_node': 'train.checkpoint.period.interval',
  262. })
  263. save_best_checkpoint: bool = field(
  264. default=False,
  265. metadata={
  266. 'help':
  267. 'Save the checkpoint(if it\'s the best) after the evaluation.',
  268. 'cfg_node': 'train.checkpoint.best.save_best',
  269. })
  270. metric_for_best_model: str = field(
  271. default=None,
  272. metadata={
  273. 'help': 'The metric used to measure the model.',
  274. 'cfg_node': 'train.checkpoint.best.metric_key',
  275. })
  276. metric_rule_for_best_model: str = field(
  277. default='max',
  278. metadata={
  279. 'help':
  280. 'The rule to measure the model with the metric, can be `max` or `min`',
  281. 'cfg_node': 'train.checkpoint.best.rule',
  282. })
  283. max_checkpoint_num: int = field(
  284. default=None,
  285. metadata={
  286. 'help':
  287. 'The max number of checkpoints to keep, older ones will be deleted.',
  288. 'cfg_node': 'train.checkpoint.period.max_checkpoint_num',
  289. })
  290. max_checkpoint_num_best: int = field(
  291. default=1,
  292. metadata={
  293. 'help':
  294. 'The max number of best checkpoints to keep, worse ones will be deleted.',
  295. 'cfg_node': 'train.checkpoint.best.max_checkpoint_num',
  296. })
  297. push_to_hub: bool = field(
  298. default=False,
  299. metadata={
  300. 'help': 'Push to hub after each checkpointing',
  301. 'cfg_node': 'train.checkpoint.period.push_to_hub',
  302. })
  303. repo_id: str = field(
  304. default=None,
  305. metadata={
  306. 'help':
  307. 'The repo id in modelhub, usually the format is "group/model"',
  308. 'cfg_node': 'train.checkpoint.period.hub_repo_id',
  309. })
  310. hub_token: str = field(
  311. default=None,
  312. metadata={
  313. 'help':
  314. 'The modelhub token, you can also set the token to the env variable `MODELSCOPE_API_TOKEN`',
  315. 'cfg_node': 'train.checkpoint.period.hub_token',
  316. })
  317. private_hub: bool = field(
  318. default=True,
  319. metadata={
  320. 'help': 'Upload to a private hub',
  321. 'cfg_node': 'train.checkpoint.period.private_hub',
  322. })
  323. hub_revision: str = field(
  324. default='master',
  325. metadata={
  326. 'help': 'Which branch to commit to',
  327. 'cfg_node': 'train.checkpoint.period.hub_revision',
  328. })
  329. push_to_hub_best: bool = field(
  330. default=False,
  331. metadata={
  332. 'help': 'Push to hub after each checkpointing',
  333. 'cfg_node': 'train.checkpoint.best.push_to_hub',
  334. })
  335. repo_id_best: str = field(
  336. default=None,
  337. metadata={
  338. 'help':
  339. 'The repo id in modelhub, usually the format is "group/model"',
  340. 'cfg_node': 'train.checkpoint.best.hub_repo_id',
  341. })
  342. hub_token_best: str = field(
  343. default=None,
  344. metadata={
  345. 'help':
  346. 'The modelhub token, you can also set the token to the env variable `MODELSCOPE_API_TOKEN`',
  347. 'cfg_node': 'train.checkpoint.best.hub_token',
  348. })
  349. private_hub_best: bool = field(
  350. default=True,
  351. metadata={
  352. 'help': 'Upload to a private hub',
  353. 'cfg_node': 'train.checkpoint.best.private_hub',
  354. })
  355. hub_revision_best: str = field(
  356. default='master',
  357. metadata={
  358. 'help': 'Which branch to commit to',
  359. 'cfg_node': 'train.checkpoint.best.hub_revision',
  360. })
  361. @dataclass(init=False)
  362. class TrainingArgs(DatasetArgs, TrainArgs, ModelArgs):
  363. use_model_config: bool = field(
  364. default=False,
  365. metadata={
  366. 'help':
  367. 'Use the configuration of the model, '
  368. 'default will only use the parameters in the CLI and the dataclass',
  369. })
  370. def __init__(self, **kwargs):
  371. self.manual_args = list(kwargs.keys())
  372. for f in fields(self):
  373. if f.name in kwargs:
  374. setattr(self, f.name, kwargs[f.name])
  375. self._unknown_args = {}
  376. def parse_cli(self, parser_args=None):
  377. """Construct a TrainingArg class by the parameters of CLI.
  378. Returns:
  379. Self
  380. """
  381. parser = CliArgumentParser(self)
  382. args, unknown = parser.parse_known_args(parser_args)
  383. unknown = [
  384. item for item in unknown
  385. if item not in ('\\', '\n') and '--local-rank=' not in item
  386. ]
  387. _unknown = {}
  388. for i in range(0, len(unknown), 2):
  389. _unknown[unknown[i].replace('-', '')] = parse_value(unknown[i + 1])
  390. args_dict = vars(args)
  391. self.manual_args += parser.manual_args
  392. self._unknown_args.update(_unknown)
  393. for key, value in deepcopy(args_dict).items():
  394. if key is not None and hasattr(self, key):
  395. setattr(self, key, value)
  396. return self
  397. def to_config(self, ignore_default_config=None):
  398. """Convert the TrainingArgs to the `Config`
  399. Returns:
  400. The Config, and extra parameters in dict.
  401. """
  402. cfg = Config()
  403. args_dict = addict.Dict()
  404. if ignore_default_config is None:
  405. ignore_default_config = self.use_model_config
  406. for f in fields(self):
  407. cfg_node = f.metadata.get('cfg_node')
  408. cfg_setter = f.metadata.get('cfg_setter') or (lambda x: x)
  409. if cfg_node is not None:
  410. if f.name in self.manual_args or not ignore_default_config:
  411. if isinstance(cfg_node, str):
  412. cfg_node = [cfg_node]
  413. for _node in cfg_node:
  414. cfg.merge_from_dict(
  415. {_node: cfg_setter(getattr(self, f.name))})
  416. else:
  417. args_dict[f.name] = getattr(self, f.name)
  418. cfg.merge_from_dict(self._unknown_args)
  419. return cfg, args_dict
  420. def get_metadata(self, key):
  421. _fields = fields(self)
  422. for f in _fields:
  423. if f.name == key:
  424. return f
  425. return None
  426. def build_dataset_from_file(filename):
  427. """
  428. The filename format:
  429. [
  430. {
  431. "dataset": {
  432. "dataset_name": "xxx",
  433. ...
  434. },
  435. "column_mapping": {
  436. "text1": "sequence1",
  437. "text2": "sequence2",
  438. "label": "label",
  439. }
  440. "usage": 0.8,
  441. }
  442. ]
  443. """
  444. from modelscope import MsDataset
  445. train_set = []
  446. eval_set = []
  447. with open(filename, 'r') as f:
  448. ds_json = json.load(f)
  449. for ds in ds_json:
  450. dataset = MsDataset.load(**ds['dataset']).to_hf_dataset()
  451. all_columns = dataset.column_names
  452. keep_columns = ds['column_mapping'].keys()
  453. remove_columns = [
  454. column for column in all_columns if column not in keep_columns
  455. ]
  456. from datasets import Features
  457. from datasets import Value
  458. from datasets import ClassLabel
  459. features = [
  460. f for f in dataset.features.items() if f[0] in keep_columns
  461. ]
  462. new_features = {}
  463. for f in features:
  464. if isinstance(f[1], ClassLabel):
  465. new_features[f[0]] = Value(f[1].dtype)
  466. else:
  467. new_features[f[0]] = f[1]
  468. new_features = Features(new_features)
  469. dataset = dataset.map(
  470. lambda x: x,
  471. remove_columns=remove_columns,
  472. features=new_features).rename_columns(ds['column_mapping'])
  473. usage = ds['usage']
  474. if isinstance(usage, str):
  475. assert usage in ('train', 'val')
  476. if usage == 'train':
  477. train_set.append(dataset)
  478. else:
  479. eval_set.append(dataset)
  480. else:
  481. assert isinstance(usage, float) and 0 < usage < 1
  482. ds_dict = dataset.train_test_split(train_size=usage)
  483. train_set.append(ds_dict['train'])
  484. eval_set.append(ds_dict['test'])
  485. from datasets import concatenate_datasets
  486. return concatenate_datasets(train_set), concatenate_datasets(eval_set)
  487. def parse_value(value: str) -> Union[str, float, bool, None]:
  488. const_map = {
  489. 'True': True,
  490. 'true': True,
  491. 'False': False,
  492. 'false': False,
  493. 'None': None,
  494. 'none': None,
  495. 'null': None
  496. }
  497. if value in const_map:
  498. return const_map[value]
  499. elif '"' in value or "'" in value:
  500. return value.replace('"', '').replace("'", '')
  501. elif re.match(r'^\d+$', value):
  502. return int(value)
  503. elif re.match(r'[+-]?(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?',
  504. value):
  505. return float(value)
  506. else:
  507. return value