arguments.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """argparser configuration"""
  15. import argparse
  16. import os
  17. import deepspeed
  18. import json
  19. import torch
  20. from .utils import get_hostname
  21. def add_model_config_args(parser):
  22. """Model arguments"""
  23. group = parser.add_argument_group('model', 'model configuration')
  24. group.add_argument(
  25. '--transformer-xl',
  26. action='store_true',
  27. help='use transformer-xl for training')
  28. group.add_argument(
  29. '--pretrained-bert',
  30. action='store_true',
  31. help='use a pretrained bert-large-uncased model instead'
  32. 'of initializing from scratch. See '
  33. '--tokenizer-model-type to specify which pretrained '
  34. 'BERT model to use')
  35. group.add_argument(
  36. '--encoder-decoder',
  37. action='store_true',
  38. help='use the encoder-decoder architecture for blocklm')
  39. group.add_argument(
  40. '--attention-dropout',
  41. type=float,
  42. default=0.1,
  43. help='dropout probability for attention weights')
  44. group.add_argument(
  45. '--num-attention-heads',
  46. type=int,
  47. default=16,
  48. help='num of transformer attention heads')
  49. group.add_argument(
  50. '--hidden-size', type=int, default=1024, help='tansformer hidden size')
  51. group.add_argument(
  52. '--intermediate-size',
  53. type=int,
  54. default=None,
  55. help='transformer embedding dimension for FFN'
  56. 'set to 4*`--hidden-size` if it is None')
  57. group.add_argument(
  58. '--num-layers', type=int, default=24, help='num decoder layers')
  59. group.add_argument(
  60. '--layernorm-epsilon',
  61. type=float,
  62. default=1e-5,
  63. help='layer norm epsilon')
  64. group.add_argument(
  65. '--hidden-dropout',
  66. type=float,
  67. default=0.1,
  68. help='dropout probability for hidden state transformer')
  69. group.add_argument(
  70. '--output-dropout',
  71. type=float,
  72. default=0.1,
  73. help='dropout probability for pooled output')
  74. group.add_argument(
  75. '--max-position-embeddings',
  76. type=int,
  77. default=512,
  78. help='maximum number of position embeddings to use')
  79. group.add_argument(
  80. '--vocab-size',
  81. type=int,
  82. default=250112,
  83. help='vocab size to use for non-character-level '
  84. 'tokenization. This value will only be used when '
  85. 'creating a tokenizer')
  86. group.add_argument(
  87. '--deep-init',
  88. action='store_true',
  89. help='initialize bert model similar to gpt2 model.'
  90. 'scales initialization of projection layers by a '
  91. 'factor of 1/sqrt(2N). Necessary to train bert '
  92. 'models larger than BERT-Large.')
  93. group.add_argument(
  94. '--make-vocab-size-divisible-by',
  95. type=int,
  96. default=128,
  97. help='Pad the vocab size to be divisible by this value.'
  98. 'This is added for computational efficieny reasons.')
  99. group.add_argument(
  100. '--cpu-optimizer', action='store_true', help='Run optimizer on CPU')
  101. group.add_argument(
  102. '--cpu_torch_adam',
  103. action='store_true',
  104. help='Use Torch Adam as optimizer on CPU.')
  105. return parser
  106. def add_fp16_config_args(parser):
  107. """Mixed precision arguments."""
  108. group = parser.add_argument_group('fp16', 'fp16 configurations')
  109. group.add_argument(
  110. '--fp16', action='store_true', help='Run model in fp16 mode')
  111. group.add_argument(
  112. '--fp32-embedding', action='store_true', help='embedding in fp32')
  113. group.add_argument(
  114. '--fp32-layernorm', action='store_true', help='layer norm in fp32')
  115. group.add_argument(
  116. '--fp32-tokentypes',
  117. action='store_true',
  118. help='embedding token types in fp32')
  119. group.add_argument(
  120. '--fp32-allreduce', action='store_true', help='all-reduce in fp32')
  121. group.add_argument(
  122. '--hysteresis',
  123. type=int,
  124. default=2,
  125. help='hysteresis for dynamic loss scaling')
  126. group.add_argument(
  127. '--loss-scale',
  128. type=float,
  129. default=None,
  130. help='Static loss scaling, positive power of 2 '
  131. 'values can improve fp16 convergence. If None, dynamic'
  132. 'loss scaling is used.')
  133. group.add_argument(
  134. '--loss-scale-window',
  135. type=float,
  136. default=1000,
  137. help='Window over which to raise/lower dynamic scale')
  138. group.add_argument(
  139. '--min-scale',
  140. type=float,
  141. default=1,
  142. help='Minimum loss scale for dynamic loss scale')
  143. group.add_argument('--attention-scale', type=float, default=1.0)
  144. return parser
  145. def add_training_args(parser):
  146. """Training arguments."""
  147. group = parser.add_argument_group('train', 'training configurations')
  148. group.add_argument(
  149. '--experiment-name',
  150. type=str,
  151. default='gpt-345M',
  152. help='The experiment name for summary and checkpoint')
  153. group.add_argument(
  154. '--batch-size', type=int, default=4, help='Data Loader batch size')
  155. group.add_argument(
  156. '--gradient-accumulation-steps',
  157. type=int,
  158. default=1,
  159. help='Data Loader batch size')
  160. group.add_argument(
  161. '--weight-decay',
  162. type=float,
  163. default=0.01,
  164. help='weight decay coefficient for L2 regularization')
  165. group.add_argument(
  166. '--checkpoint-activations',
  167. action='store_true',
  168. help='checkpoint activation to allow for training '
  169. 'with larger models and sequences')
  170. group.add_argument(
  171. '--checkpoint-num-layers',
  172. type=int,
  173. default=1,
  174. help='chunk size (number of layers) for checkpointing')
  175. group.add_argument(
  176. '--deepspeed-activation-checkpointing',
  177. action='store_true',
  178. help='uses activation checkpointing from deepspeed')
  179. group.add_argument(
  180. '--epochs',
  181. type=int,
  182. default=None,
  183. help='Number of finetunning epochs. Zero results in evaluation only.')
  184. group.add_argument(
  185. '--clip-grad', type=float, default=1.0, help='gradient clipping')
  186. group.add_argument(
  187. '--train-iters',
  188. type=int,
  189. default=0,
  190. help='total number of iterations to train over all training runs')
  191. group.add_argument('--label-smoothing', type=float, default=0.0)
  192. group.add_argument(
  193. '--log-interval', type=int, default=100, help='report interval')
  194. group.add_argument(
  195. '--summary-dir',
  196. type=str,
  197. default='',
  198. help='The directory to store the summary')
  199. group.add_argument('--seed', type=int, default=1234, help='random seed')
  200. # Batch producer arguments
  201. group.add_argument(
  202. '--reset-position-ids',
  203. action='store_true',
  204. help='Reset position ids after end-of-document token.')
  205. group.add_argument(
  206. '--reset-attention-mask',
  207. action='store_true',
  208. help='Reset self attention maske after '
  209. 'end-of-document token.')
  210. # Learning rate.
  211. group.add_argument(
  212. '--lr-decay-iters',
  213. type=int,
  214. default=None,
  215. help='number of iterations to decay LR over,'
  216. ' If None defaults to `--train-iters`*`--epochs`')
  217. group.add_argument(
  218. '--lr-decay-style',
  219. type=str,
  220. default='linear',
  221. choices=['constant', 'linear', 'cosine', 'exponential'],
  222. help='learning rate decay function')
  223. group.add_argument('--lr-decay-ratio', type=float, default=0.1)
  224. group.add_argument(
  225. '--lr', type=float, default=1.0e-4, help='initial learning rate')
  226. group.add_argument(
  227. '--warmup',
  228. type=float,
  229. default=0.01,
  230. help='percentage of data to warmup on (.01 = 1% of all '
  231. 'training iters). Default 0.01')
  232. group.add_argument(
  233. '--switch-linear',
  234. action='store_true',
  235. help='Switch to linear decay for cosine decay')
  236. # model checkpointing
  237. group.add_argument(
  238. '--save',
  239. type=str,
  240. default=None,
  241. help='Output directory to save checkpoints to.')
  242. group.add_argument('--new-save-directory', action='store_true')
  243. group.add_argument(
  244. '--save-epoch',
  245. type=int,
  246. default=1,
  247. help='number of epochs between saves')
  248. group.add_argument(
  249. '--save-interval',
  250. type=int,
  251. default=5000,
  252. help='number of iterations between saves')
  253. group.add_argument(
  254. '--no-save-optim',
  255. action='store_true',
  256. help='Do not save current optimizer.')
  257. group.add_argument(
  258. '--no-save-rng',
  259. action='store_true',
  260. help='Do not save current rng state.')
  261. group.add_argument(
  262. '--load',
  263. type=str,
  264. default=None,
  265. help='Path to a directory containing a model checkpoint.')
  266. group.add_argument(
  267. '--no-load-optim',
  268. action='store_true',
  269. help='Do not load optimizer when loading checkpoint.')
  270. group.add_argument(
  271. '--no-load-rng',
  272. action='store_true',
  273. help='Do not load rng state when loading checkpoint.')
  274. group.add_argument(
  275. '--no-load-lr-scheduler',
  276. action='store_true',
  277. help='Do not load lr scheduler when loading checkpoint.')
  278. group.add_argument(
  279. '--no-deepspeed-load',
  280. action='store_true',
  281. help='Not use deepspeed when loading checkpoint')
  282. group.add_argument(
  283. '--finetune',
  284. action='store_true',
  285. help='Load model for finetuning. Do not load optimizer '
  286. 'or rng state from checkpoint and set iteration to 0. '
  287. 'Assumed when loading a release checkpoint.')
  288. group.add_argument(
  289. '--resume-dataloader',
  290. action='store_true',
  291. help='Resume the dataloader when resuming training. '
  292. 'Does not apply to tfrecords dataloader, try resuming'
  293. 'with a different seed in this case.')
  294. # distributed training args
  295. group.add_argument(
  296. '--distributed-backend',
  297. default='nccl',
  298. help=
  299. 'which backend to use for distributed training. One of [gloo, nccl]',
  300. choices=['nccl', 'gloo'])
  301. group.add_argument(
  302. '--DDP-impl',
  303. default='torch',
  304. choices=['local', 'torch', 'none'],
  305. help='which DistributedDataParallel implementation to use.')
  306. group.add_argument(
  307. '--local_rank',
  308. type=int,
  309. default=None,
  310. help='local rank passed from distributed launcher')
  311. # BlockLM training args
  312. group.add_argument(
  313. '--block-lm',
  314. action='store_true',
  315. help='whether use the BlockLM pre-training')
  316. group.add_argument(
  317. '--masked-lm',
  318. action='store_true',
  319. help='whether to use the mlm objective')
  320. group.add_argument('--bert-prob', type=float, default=0.5)
  321. group.add_argument('--gpt-infill-prob', type=float, default=0.5)
  322. group.add_argument('--gpt-min-ratio', type=float, default=0.5)
  323. group.add_argument('--gap-sentence-prob', type=float, default=0.0)
  324. group.add_argument('--gap-sentence-ratio', type=float, default=0.15)
  325. group.add_argument('--avg-block-length', type=int, default=3)
  326. group.add_argument('--short-seq-prob', type=float, default=0.0)
  327. group.add_argument('--single-span-prob', type=float, default=0.0)
  328. group.add_argument(
  329. '--task-mask',
  330. action='store_true',
  331. help='Use different mask for generation and blank filling')
  332. group.add_argument(
  333. '--no-shuffle-block',
  334. action='store_true',
  335. help='not shuffle the blocks when filling the blank')
  336. group.add_argument(
  337. '--no-block-position',
  338. action='store_true',
  339. help='Use (rough) absolute positions instead of block positions')
  340. group.add_argument(
  341. '--sentinel-token',
  342. action='store_true',
  343. help='Use sentinel (mask) tokens to replace 2d position encoding')
  344. group.add_argument('--block-mask-prob', type=float, default=0.0)
  345. group.add_argument('--context-mask-ratio', type=float, default=0.0)
  346. group.add_argument(
  347. '--random-position',
  348. action='store_true',
  349. help='Use random start position to cover all the position embeddings')
  350. return parser
  351. def add_evaluation_args(parser):
  352. """Evaluation arguments."""
  353. group = parser.add_argument_group('validation',
  354. 'validation configurations')
  355. group.add_argument(
  356. '--eval-batch-size',
  357. type=int,
  358. default=None,
  359. help='Data Loader batch size for evaluation datasets.'
  360. 'Defaults to `--batch-size`')
  361. group.add_argument(
  362. '--eval-iters',
  363. type=int,
  364. default=100,
  365. help='number of iterations to run for evaluation'
  366. 'validation/test for')
  367. group.add_argument(
  368. '--eval-interval',
  369. type=int,
  370. default=1000,
  371. help='interval between running evaluation on validation set')
  372. group.add_argument(
  373. '--eval-epoch',
  374. type=int,
  375. default=1,
  376. help='epoch between running evaluation on validation set')
  377. group.add_argument(
  378. '--eval-seq-length',
  379. type=int,
  380. default=None,
  381. help='Maximum sequence length to process for '
  382. 'evaluation. Defaults to `--seq-length`')
  383. group.add_argument(
  384. '--eval-max-preds-per-seq',
  385. type=int,
  386. default=None,
  387. help='Maximum number of predictions to use for '
  388. 'evaluation. Defaults to '
  389. 'math.ceil(`--eval-seq-length`*.15/10)*10')
  390. group.add_argument('--overlapping-eval', type=int, default=32)
  391. return parser
  392. def add_text_generate_args(parser):
  393. """Text generate arguments."""
  394. group = parser.add_argument_group('Text generation', 'configurations')
  395. group.add_argument('--temperature', type=float, default=1.0)
  396. group.add_argument('--top_p', type=float, default=0.0)
  397. group.add_argument('--top_k', type=int, default=0)
  398. group.add_argument('--out-seq-length', type=int, default=256)
  399. group.add_argument('--num-beams', type=int, default=1)
  400. group.add_argument('--length-penalty', type=float, default=0.0)
  401. group.add_argument('--no-repeat-ngram-size', type=int, default=0)
  402. group.add_argument('--min-tgt-length', type=int, default=0)
  403. group.add_argument('--select-topk', action='store_true')
  404. group.add_argument('--blank-maskratio', type=float, default=0.1)
  405. return parser
  406. def add_data_args(parser):
  407. """Train/valid/test data arguments."""
  408. group = parser.add_argument_group('data', 'data configurations')
  409. group.add_argument(
  410. '--model-parallel-size',
  411. type=int,
  412. default=1,
  413. help='size of the model parallel.')
  414. group.add_argument(
  415. '--shuffle',
  416. action='store_true',
  417. help='Shuffle data. Shuffling is deterministic '
  418. 'based on seed and current epoch.')
  419. group.add_argument('--filter-english', action='store_true')
  420. group.add_argument(
  421. '--train-data',
  422. nargs='+',
  423. default=None,
  424. help='Whitespace separated filenames or corpora names '
  425. 'for training.')
  426. group.add_argument(
  427. '--valid-data',
  428. nargs='*',
  429. default=None,
  430. help="""Filename for validation data.""")
  431. group.add_argument(
  432. '--test-data',
  433. nargs='*',
  434. default=None,
  435. help="""Filename for testing""")
  436. group.add_argument(
  437. '--data-dir',
  438. type=str,
  439. default=None,
  440. help='The data path to all the data files')
  441. group.add_argument(
  442. '--input-data-sizes-file',
  443. type=str,
  444. default='sizes.txt',
  445. help='the filename containing all the shards sizes')
  446. group.add_argument(
  447. '--delim', default=',', help='delimiter used to parse csv data files')
  448. group.add_argument(
  449. '--text-key',
  450. default='sentence',
  451. help='key to use to extract text from json/csv')
  452. group.add_argument(
  453. '--eval-text-key',
  454. default=None,
  455. help='key to use to extract text from '
  456. 'json/csv evaluation datasets')
  457. group.add_argument(
  458. '--split',
  459. default='1000,1,1',
  460. help='comma-separated list of proportions for training,'
  461. ' validation, and test split')
  462. group.add_argument(
  463. '--no-lazy-loader',
  464. action='store_true',
  465. help='whether to lazy read the data set')
  466. group.add_argument('--half-lazy-loader', action='store_true')
  467. group.add_argument(
  468. '--loader-scatter',
  469. type=int,
  470. default=None,
  471. help='Number of scatters to use for dataloaders')
  472. group.add_argument(
  473. '--loose-json',
  474. action='store_true',
  475. help='Use loose json (one json-formatted string per '
  476. 'newline), instead of tight json (data file is one '
  477. 'json string)')
  478. group.add_argument(
  479. '--presplit-sentences',
  480. action='store_true',
  481. help='Dataset content consists of documents where '
  482. 'each document consists of newline separated sentences')
  483. group.add_argument(
  484. '--num-workers',
  485. type=int,
  486. default=2,
  487. help="""Number of workers to use for dataloading""")
  488. group.add_argument(
  489. '--tokenizer-model-type',
  490. type=str,
  491. default=None,
  492. help="Model type to use for sentencepiece tokenization \
  493. (one of ['bpe', 'char', 'unigram', 'word']) or \
  494. bert vocab to use for BertWordPieceTokenizer (one of \
  495. ['bert-large-uncased', 'bert-large-cased', etc.])")
  496. group.add_argument(
  497. '--tokenizer-path',
  498. type=str,
  499. default='tokenizer.model',
  500. help='path used to save/load sentencepiece tokenization '
  501. 'models')
  502. group.add_argument(
  503. '--tokenizer-type',
  504. type=str,
  505. default='BertWordPieceTokenizer',
  506. choices=[
  507. 'CharacterLevelTokenizer', 'SentencePieceTokenizer',
  508. 'BertWordPieceTokenizer', 'GPT2BPETokenizer', 'ChineseSPTokenizer'
  509. ],
  510. help='what type of tokenizer to use')
  511. group.add_argument('--no-pre-tokenize', action='store_true')
  512. group.add_argument(
  513. '--cache-dir',
  514. default=None,
  515. type=str,
  516. help='Where to store pre-trained BERT downloads')
  517. group.add_argument(
  518. '--use-tfrecords',
  519. action='store_true',
  520. help='load `--train-data`, `--valid-data`, '
  521. '`--test-data` from BERT tf records instead of '
  522. 'normal data pipeline')
  523. group.add_argument(
  524. '--seq-length',
  525. type=int,
  526. default=512,
  527. help='Maximum sequence length to process')
  528. group.add_argument(
  529. '--mem-length',
  530. type=int,
  531. default=0,
  532. help='The memory length to preserve')
  533. group.add_argument(
  534. '--max-preds-per-seq',
  535. type=int,
  536. default=None,
  537. help='Maximum number of predictions to use per sequence.'
  538. 'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
  539. 'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
  540. group.add_argument('--non-sentence-start', type=float, default=0.0)
  541. group.add_argument(
  542. '--sample-one-document',
  543. action='store_true',
  544. help='only sample one document in one sample')
  545. group.add_argument(
  546. '--load-splits',
  547. type=str,
  548. default=None,
  549. help='The path to load split indices from')
  550. group.add_argument(
  551. '--save-splits',
  552. type=str,
  553. default=None,
  554. help='The path to save split indices to')
  555. group.add_argument(
  556. '--save-test-data',
  557. type=str,
  558. default=None,
  559. help='The path to save the test data')
  560. group.add_argument(
  561. '--multi-task-data',
  562. nargs='*',
  563. default=None,
  564. help='Downsteam task names for multi-task pre-training')
  565. group.add_argument(
  566. '--multi-task-ratio',
  567. type=float,
  568. default=0.0,
  569. help='Ratio for multi-task pre-training')
  570. group.add_argument('--multi-seq-length', type=int, default=None)
  571. group.add_argument('--multi-batch-size', type=int, default=None)
  572. return parser
  573. def add_finetune_config_args(parser):
  574. group = parser.add_argument_group('finetune', 'finetune configurations')
  575. group.add_argument('--task', type=str, help='Task name.')
  576. group.add_argument(
  577. '--load-pretrained',
  578. type=str,
  579. help='Load pretrained model',
  580. default=None)
  581. group.add_argument(
  582. '--pool-token',
  583. type=str,
  584. choices=['start', 'pad', 'cls'],
  585. help='The token to pool the sequence representation',
  586. default='cls')
  587. group.add_argument(
  588. '--cloze-eval',
  589. action='store_true',
  590. help='Evaluation dataset with cloze task')
  591. group.add_argument(
  592. '--multi-token',
  593. action='store_true',
  594. help='Use multi token for cloze evaluation')
  595. group.add_argument(
  596. '--segment-length',
  597. type=int,
  598. default=0,
  599. help='The maximum segment length for cloze evaluation')
  600. group.add_argument(
  601. '--loss-func',
  602. type=str,
  603. choices=['cross_entropy', 'hinge', 'generative', 'mix'],
  604. default='cross_entropy')
  605. group.add_argument('--block-lm-ratio', type=float, default=0.0)
  606. group.add_argument(
  607. '--adapet',
  608. action='store_true',
  609. help='Use the decoupled cross entropy loss in AdaPET')
  610. group.add_argument('--pattern-id', type=int, default=0)
  611. group.add_argument(
  612. '--fast-decode',
  613. action='store_true',
  614. help=
  615. 'Fast decode for multi-token cloze. Can only be used without checkpoint activation.'
  616. )
  617. group.add_argument('--few-superglue', action='store_true')
  618. group.add_argument(
  619. '--eval-valid',
  620. action='store_true',
  621. help='Whether evaluate on the valid set')
  622. group.add_argument('--validation-metric', type=str, default=None)
  623. group.add_argument(
  624. '--unidirectional',
  625. action='store_true',
  626. help='Use the left to right language model')
  627. group.add_argument('--src-seq-length', type=int, default=None)
  628. group.add_argument('--tgt-seq-length', type=int, default=None)
  629. group.add_argument('--adam-beta1', type=float, default=0.9)
  630. group.add_argument('--adam-beta2', type=float, default=0.999)
  631. group.add_argument('--adam-eps', type=float, default=1e-8)
  632. group.add_argument(
  633. '--optimizer', type=str, choices=['adam', 'adafactor'], default='adam')
  634. group.add_argument('--wsc-negative', action='store_true')
  635. group.add_argument('--overwrite', action='store_true')
  636. group.add_argument('--no-validation', action='store_true')
  637. # Continuous prompt arguments
  638. group.add_argument(
  639. '--continuous-prompt',
  640. action='store_true',
  641. help='Use continuous prompt for PET')
  642. group.add_argument('--num-prompt-tokens', type=int, default=0)
  643. group.add_argument(
  644. '--prompt-func', default='lstm', choices=['lstm', 'mlp', 'none'])
  645. group.add_argument(
  646. '--freeze-transformer', action='store_true', default=False)
  647. group.add_argument('--tune-prefix-layers', type=int, default=None)
  648. group.add_argument('--prefix-prompt', type=int, default=0)
  649. group.add_argument('--prompt-init', action='store_true', default=False)
  650. return parser
  651. def get_args():
  652. """Parse all the args."""
  653. parser = argparse.ArgumentParser(description='PyTorch BERT Model')
  654. parser = add_model_config_args(parser)
  655. parser = add_fp16_config_args(parser)
  656. parser = add_training_args(parser)
  657. parser = add_evaluation_args(parser)
  658. parser = add_text_generate_args(parser)
  659. parser = add_data_args(parser)
  660. parser = add_finetune_config_args(parser)
  661. # Include DeepSpeed configuration arguments
  662. parser = deepspeed.add_config_arguments(parser)
  663. args = parser.parse_args(args=[])
  664. if not args.train_data and not args.data_dir:
  665. print('WARNING: No training data specified')
  666. args.cuda = torch.cuda.is_available()
  667. args.rank = int(os.getenv('RANK', '0'))
  668. args.world_size = int(os.getenv('WORLD_SIZE', '1'))
  669. if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
  670. mpi_define_env(args)
  671. elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
  672. # We are using (OpenMPI) mpirun for launching distributed data parallel processes
  673. local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
  674. local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
  675. # Possibly running with Slurm
  676. num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
  677. nodeid = int(os.getenv('SLURM_NODEID', '0'))
  678. args.local_rank = local_rank
  679. args.rank = nodeid * local_size + local_rank
  680. args.world_size = num_nodes * local_size
  681. args.model_parallel_size = min(args.model_parallel_size, args.world_size)
  682. if args.rank == 0:
  683. print('using world size: {} and model-parallel size: {} '.format(
  684. args.world_size, args.model_parallel_size))
  685. args.dynamic_loss_scale = False
  686. if args.loss_scale is None:
  687. args.dynamic_loss_scale = True
  688. if args.rank == 0:
  689. print(' > using dynamic loss scaling')
  690. # The args fp32_* or fp16_* meant to be active when the
  691. # args fp16 is set. So the default behaviour should all
  692. # be false.
  693. if not args.fp16:
  694. args.fp32_embedding = False
  695. args.fp32_tokentypes = False
  696. args.fp32_layernorm = False
  697. if hasattr(args, 'deepspeed'
  698. ) and args.deepspeed and args.deepspeed_config is not None:
  699. with open(args.deepspeed_config, encoding='utf-8') as file:
  700. deepspeed_config = json.load(file)
  701. if 'train_micro_batch_size_per_gpu' in deepspeed_config:
  702. args.batch_size = deepspeed_config[
  703. 'train_micro_batch_size_per_gpu']
  704. if 'gradient_accumulation_steps' in deepspeed_config:
  705. args.gradient_accumulation_steps = deepspeed_config[
  706. 'gradient_accumulation_steps']
  707. else:
  708. args.gradient_accumulation_steps = 1
  709. if 'optimizer' in deepspeed_config:
  710. optimizer_params_config = deepspeed_config['optimizer'].get(
  711. 'params', {})
  712. args.lr = optimizer_params_config.get('lr', args.lr)
  713. args.weight_decay = optimizer_params_config.get(
  714. 'weight_decay', args.weight_decay)
  715. return args
  716. def mpi_define_env(args):
  717. from mpi4py import MPI
  718. comm = MPI.COMM_WORLD
  719. rank = comm.Get_rank()
  720. world_size = comm.Get_size()
  721. master_addr = None
  722. if rank == 0:
  723. master_addr = get_hostname()
  724. master_addr = comm.bcast(master_addr, root=0)
  725. # Determine local rank by assuming hostnames are unique
  726. proc_name = MPI.Get_processor_name()
  727. all_procs = comm.allgather(proc_name)
  728. local_rank = sum([i == proc_name for i in all_procs[:rank]])
  729. os.environ['RANK'] = str(rank)
  730. os.environ['WORLD_SIZE'] = str(world_size)
  731. args.local_rank = local_rank
  732. args.world_size = world_size
  733. args.rank = rank
  734. os.environ['MASTER_ADDR'] = master_addr
  735. os.environ[
  736. 'MASTER_PORT'] = '29500' # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
  737. print(
  738. 'Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}'
  739. .format(os.environ['RANK'], args.local_rank, os.environ['WORLD_SIZE'],
  740. os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']))