configure_data.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """parses arguments and preps data loader"""
  15. import copy
  16. import os
  17. import random
  18. from bisect import bisect_right
  19. from itertools import accumulate
  20. import numpy as np
  21. import torch
  22. import torch.utils.data
  23. from megatron_util import mpu, print_rank_0
  24. from . import data_utils
  25. from .blocklm_utils import ConstructBlockStrategy
  26. from .data_utils.tokenization import make_tokenizer
  27. class MultiTaskDataset(torch.utils.data.Dataset):
  28. def __init__(self,
  29. tasks,
  30. datasets,
  31. reweight=True,
  32. temperature=0.8,
  33. max_limit=200000):
  34. super(MultiTaskDataset, self).__init__()
  35. self.tasks = tasks
  36. self.datasets = datasets
  37. self.reweight = reweight
  38. self.temperature = temperature
  39. self.lens = [len(dataset) for dataset in datasets]
  40. self.weights = np.array(
  41. [min(length, max_limit)**temperature for length in self.lens])
  42. self.total_len = sum(self.lens)
  43. self.cumulative_lens = list(accumulate(self.lens))
  44. if self.reweight:
  45. print_rank_0(list(zip(self.tasks, self.lens, self.weights)))
  46. else:
  47. print_rank_0(list(zip(self.tasks, self.lens)))
  48. self.weights /= self.weights.sum()
  49. def __len__(self):
  50. return self.total_len * 1000
  51. @staticmethod
  52. def pet_wrapper(data):
  53. text = data['text']
  54. loss_mask = data['logit_mask']
  55. target = data['target']
  56. attention_mask = data['mask']
  57. position_id = data['position']
  58. label = data['label']
  59. if len(text.shape) == 2:
  60. text = text[label]
  61. loss_mask = loss_mask[label]
  62. target = target[label]
  63. attention_mask = attention_mask[label]
  64. position_id = position_id[label]
  65. else:
  66. target = target[label]
  67. if not target.shape:
  68. target = target.repeat(len(text))
  69. return {
  70. 'text': text,
  71. 'target': target,
  72. 'loss_mask': loss_mask,
  73. 'position_id': position_id,
  74. 'attention_mask': attention_mask
  75. }
  76. def __getitem__(self, idx):
  77. if self.reweight:
  78. rng = random.Random(idx)
  79. rng = np.random.RandomState(
  80. seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
  81. dataset_idx = rng.choice(
  82. np.arange(len(self.datasets)), p=self.weights)
  83. dataset = self.datasets[dataset_idx]
  84. sample_idx = rng.choice(np.arange(len(dataset)))
  85. item = self.datasets[dataset_idx][sample_idx]
  86. else:
  87. dataset_idx = bisect_right(self.cumulative_lens, idx)
  88. if dataset_idx == 0:
  89. sample_idx = idx
  90. else:
  91. sample_idx = idx - self.cumulative_lens[dataset_idx - 1]
  92. item = self.datasets[dataset_idx][sample_idx]
  93. item = self.pet_wrapper(item)
  94. return item
  95. class DataConfig:
  96. def __init__(self, defaults=None):
  97. super(DataConfig, self).__init__()
  98. if defaults is None:
  99. defaults = {}
  100. self.defaults = defaults
  101. def apply(self, args, tokenizer):
  102. if torch.distributed.get_rank() == 0:
  103. print('configuring data')
  104. self.apply_defaults(args)
  105. return make_loaders(args, tokenizer)
  106. def set_defaults(self, **kwargs):
  107. for k, v in kwargs.items():
  108. self.defaults[k] = v
  109. def apply_defaults(self, args):
  110. for k, v in self.defaults.items():
  111. k = k.replace('-', '_')
  112. if not hasattr(args, k):
  113. setattr(args, k, v)
  114. def prepare_tokenizer(args):
  115. add_sentinel_token = 0
  116. if args.sentinel_token:
  117. add_sentinel_token = args.max_position_embeddings
  118. tokenizer = make_tokenizer(
  119. args.tokenizer_type,
  120. None,
  121. args.tokenizer_path,
  122. args.vocab_size,
  123. args.tokenizer_model_type,
  124. add_block_symbols=args.block_lm,
  125. cache_dir=args.cache_dir,
  126. add_sentinel_token=add_sentinel_token,
  127. add_task_mask=args.task_mask,
  128. add_decoder_mask=args.block_mask_prob > 0.0
  129. or args.context_mask_ratio > 0.0)
  130. if mpu.get_model_parallel_rank() == 0:
  131. num_tokens = tokenizer.num_tokens
  132. eod_token = tokenizer.get_command('eos').Id
  133. assert eod_token == tokenizer.get_command('pad').Id
  134. before = num_tokens
  135. after = before
  136. multiple = args.make_vocab_size_divisible_by
  137. while (after % multiple) != 0:
  138. after += 1
  139. print_rank_0('> padded vocab (size: {}) with {} dummy '
  140. 'tokens (new size: {})'.format(before, after - before,
  141. after))
  142. print_rank_0('> found end-of-document token: {}'.format(eod_token))
  143. token_counts = torch.cuda.LongTensor([after, eod_token])
  144. else:
  145. token_counts = torch.cuda.LongTensor([0, 0])
  146. # Broadcast num tokens.
  147. torch.distributed.broadcast(
  148. token_counts,
  149. mpu.get_model_parallel_src_rank(),
  150. group=mpu.get_model_parallel_group())
  151. num_tokens = token_counts[0].item()
  152. eod_token = token_counts[1].item()
  153. args.vocab_size, args.eod_token = num_tokens, eod_token
  154. return tokenizer
  155. def make_data_loader(dataset,
  156. tokenizer,
  157. batch_size,
  158. num_iters,
  159. args,
  160. shuffle=False,
  161. block_collate=False):
  162. world_size = torch.distributed.get_world_size(
  163. group=mpu.get_data_parallel_group())
  164. rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
  165. if args.loader_scatter is not None:
  166. rank = rank // args.loader_scatter
  167. world_size = world_size // args.loader_scatter
  168. batch_size = batch_size // args.loader_scatter
  169. distributed = world_size > 1
  170. if args.transformer_xl:
  171. batch_sampler = data_utils.samplers.DistributedSequentialSampler(
  172. len(dataset), num_iters, batch_size, rank, world_size)
  173. else:
  174. if shuffle:
  175. sampler = data_utils.samplers.RandomSampler(
  176. dataset,
  177. replacement=True,
  178. num_samples=batch_size * args.train_iters
  179. * args.gradient_accumulation_steps)
  180. else:
  181. sampler = torch.utils.data.SequentialSampler(dataset)
  182. drop_last = distributed
  183. # the GPUs in the same model parallel group receive the same data
  184. if distributed:
  185. batch_sampler = data_utils.samplers.DistributedBatchSampler(
  186. sampler,
  187. batch_size,
  188. drop_last,
  189. rank,
  190. world_size,
  191. gradient_accumulation_steps=args.gradient_accumulation_steps)
  192. else:
  193. batch_sampler = torch.utils.data.BatchSampler(
  194. sampler, batch_size, drop_last)
  195. collate_fn = None
  196. if block_collate:
  197. collate_fn = ConstructBlockStrategy(
  198. args,
  199. tokenizer,
  200. args.seq_length,
  201. bert_prob=args.bert_prob,
  202. gap_sentence_prob=args.gap_sentence_prob,
  203. gap_sentence_ratio=args.gap_sentence_ratio,
  204. gpt_infill_prob=args.gpt_infill_prob,
  205. average_block_length=args.avg_block_length,
  206. gpt_min_ratio=args.gpt_min_ratio,
  207. block_mask_prob=args.block_mask_prob,
  208. context_mask_ratio=args.context_mask_ratio,
  209. short_seq_prob=args.short_seq_prob,
  210. single_span_prob=args.single_span_prob,
  211. shuffle_blocks=not args.no_shuffle_block,
  212. block_position_encoding=not args.no_block_position,
  213. sentinel_token=args.sentinel_token,
  214. encoder_decoder=args.encoder_decoder,
  215. task_mask=args.task_mask,
  216. random_position=args.random_position,
  217. masked_lm=args.masked_lm).construct_blocks
  218. data_loader = torch.utils.data.DataLoader(
  219. dataset,
  220. batch_sampler=batch_sampler,
  221. num_workers=args.num_workers,
  222. pin_memory=True,
  223. collate_fn=collate_fn)
  224. return data_loader
  225. def make_tfrecord_loaders(args):
  226. """Load train/val/test dataset from shuffled TFRecords"""
  227. import data_utils.tf_dl
  228. data_set_args = {
  229. 'batch_size': args.batch_size,
  230. 'max_seq_len': args.seq_length,
  231. 'max_preds_per_seq': args.max_preds_per_seq,
  232. 'train': True,
  233. 'num_workers': max(args.num_workers, 1),
  234. 'seed': args.seed + args.rank + 1,
  235. 'threaded_dl': args.num_workers > 0
  236. }
  237. train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
  238. **data_set_args)
  239. data_set_args['train'] = False
  240. if args.eval_seq_length is not None:
  241. data_set_args['max_seq_len'] = args.eval_seq_length
  242. if args.eval_max_preds_per_seq is not None:
  243. data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
  244. valid = None
  245. if args.valid_data is not None:
  246. valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data,
  247. **data_set_args)
  248. test = None
  249. if args.test_data is not None:
  250. test = data_utils.tf_dl.TFRecordDataLoader(args.test_data,
  251. **data_set_args)
  252. tokenizer = data_utils.make_tokenizer(
  253. args.tokenizer_type,
  254. train,
  255. args.tokenizer_path,
  256. args.vocab_size,
  257. args.tokenizer_model_type,
  258. cache_dir=args.cache_dir)
  259. return (train, valid, test), tokenizer
  260. def make_loaders(args, tokenizer):
  261. """makes training/val/test"""
  262. if args.use_tfrecords:
  263. return make_tfrecord_loaders(args)
  264. world_size = torch.distributed.get_world_size(
  265. group=mpu.get_data_parallel_group())
  266. if args.loader_scatter is not None:
  267. assert world_size % args.loader_scatter == 0
  268. batch_size = args.batch_size * world_size
  269. eval_batch_size = batch_size
  270. if args.eval_batch_size is not None:
  271. eval_batch_size = args.eval_batch_size * world_size
  272. seq_length = args.seq_length
  273. if seq_length < 0:
  274. seq_length = seq_length * world_size
  275. eval_seq_length = args.eval_seq_length
  276. if eval_seq_length is not None and eval_seq_length < 0:
  277. eval_seq_length = eval_seq_length * world_size
  278. split = get_split(args)
  279. data_set_args = {
  280. 'path': args.train_data,
  281. 'seq_length': seq_length,
  282. 'mem_length': args.mem_length,
  283. 'delim': args.delim,
  284. 'text_key': args.text_key,
  285. 'label_key': 'label',
  286. 'ds_type': args.data_set_type,
  287. 'split': split,
  288. 'loose': args.loose_json,
  289. 'max_preds_per_seq': args.max_preds_per_seq,
  290. 'presplit_sentences': args.presplit_sentences,
  291. 'sample_one_document': args.sample_one_document,
  292. 'filter_english': args.filter_english,
  293. 'pre_tokenize': not args.no_pre_tokenize,
  294. 'tokenizer': tokenizer,
  295. 'save_splits': args.save_splits,
  296. 'load_splits': args.load_splits,
  297. 'save_test_data': args.save_test_data,
  298. 'no_lazy_loader': args.no_lazy_loader,
  299. 'loader_scatter': args.loader_scatter,
  300. 'data_parallel_rank': mpu.get_data_parallel_rank(),
  301. 'non_sentence_start': args.non_sentence_start,
  302. 'half_lazy_loader': args.half_lazy_loader
  303. }
  304. eval_set_args = copy.copy(data_set_args)
  305. eval_set_args['split'] = [1.]
  306. # if optional eval args were set then replace their
  307. # equivalent values in the arg dict
  308. if eval_seq_length:
  309. eval_set_args['seq_length'] = eval_seq_length
  310. if args.eval_max_preds_per_seq:
  311. eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
  312. if args.eval_text_key is not None:
  313. eval_set_args['text_key'] = args.eval_text_key
  314. # make datasets splits and tokenizer
  315. train, valid, test = None, None, None
  316. if args.train_data is not None:
  317. train = data_utils.make_dataset(**data_set_args)
  318. if data_utils.should_split(split):
  319. train, valid, test = train
  320. eval_set_args['tokenizer'] = tokenizer
  321. # make training and val dataset if necessary
  322. if valid is None and args.valid_data is not None:
  323. eval_set_args['path'] = args.valid_data
  324. valid = data_utils.make_dataset(**eval_set_args)
  325. eval_set_args['tokenizer'] = tokenizer
  326. if test is None and args.test_data is not None:
  327. eval_set_args['path'] = args.test_data
  328. test = data_utils.make_dataset(**eval_set_args)
  329. # wrap datasets with data loader
  330. use_block = args.block_lm or args.encoder_decoder
  331. if train is not None and args.batch_size > 0:
  332. train = make_data_loader(
  333. train,
  334. tokenizer,
  335. batch_size,
  336. args.train_iters,
  337. args,
  338. shuffle=args.shuffle,
  339. block_collate=use_block)
  340. args.do_train = True
  341. else:
  342. args.do_train = False
  343. eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
  344. if valid is not None:
  345. valid = make_data_loader(
  346. valid,
  347. tokenizer,
  348. eval_batch_size,
  349. args.train_iters,
  350. args,
  351. shuffle=args.shuffle,
  352. block_collate=use_block)
  353. args.do_valid = True
  354. else:
  355. args.do_valid = False
  356. if test is not None:
  357. test = make_data_loader(
  358. test,
  359. tokenizer,
  360. eval_batch_size,
  361. len(test) // eval_batch_size + 1,
  362. args,
  363. shuffle=args.shuffle,
  364. block_collate=use_block)
  365. args.do_test = True
  366. else:
  367. args.do_test = False
  368. return train, valid, test
  369. def build_multi_task_dataset(args, tokenizer):
  370. task_dirs = {
  371. 'mnli': 'MNLI',
  372. 'cola': 'CoLA',
  373. 'mrpc': 'MRPC',
  374. 'qnli': 'QNLI',
  375. 'qqp': 'QQP',
  376. 'sst2': 'SST-2',
  377. 'agnews': 'Agnews',
  378. 'yelp-polarity': 'yelp_review_polarity_csv',
  379. 'yelp-full': 'yelp_review_full_csv',
  380. 'yahoo': 'Yahoo',
  381. 'squad': 'SQuAD',
  382. 'race': 'RACE'
  383. }
  384. train, valid = None, None
  385. if mpu.get_model_parallel_rank() == 0:
  386. multi_seq_length = args.seq_length
  387. if args.multi_seq_length is not None:
  388. multi_seq_length = args.multi_seq_length
  389. train_datasets, valid_datasets = [], []
  390. for task in args.multi_task_data:
  391. task = task.lower()
  392. data_dir = os.path.join(args.data_dir, task_dirs[task])
  393. train_datasets.append(
  394. SuperGlueDataset(
  395. args,
  396. task,
  397. data_dir,
  398. multi_seq_length,
  399. 'train',
  400. tokenizer,
  401. pattern_ensemble=True))
  402. valid_datasets.append(
  403. SuperGlueDataset(
  404. args,
  405. task,
  406. data_dir,
  407. multi_seq_length,
  408. 'dev',
  409. tokenizer,
  410. pattern_ensemble=True))
  411. train = MultiTaskDataset(args.multi_task_data, train_datasets)
  412. valid = MultiTaskDataset(args.multi_task_data, valid_datasets)
  413. world_size = torch.distributed.get_world_size(
  414. group=mpu.get_data_parallel_group())
  415. multi_batch_size = args.batch_size * world_size
  416. if args.multi_batch_size is not None:
  417. multi_batch_size = args.multi_batch_size * world_size
  418. train = make_data_loader(
  419. train,
  420. tokenizer,
  421. multi_batch_size,
  422. args.train_iters,
  423. args,
  424. shuffle=True)
  425. valid = make_data_loader(
  426. valid,
  427. tokenizer,
  428. multi_batch_size,
  429. args.train_iters,
  430. args,
  431. shuffle=True)
  432. return train, valid
  433. def get_split(args):
  434. """
  435. Get dataset splits from comma separated string list
  436. """
  437. splits = []
  438. if args.split.find(',') != -1:
  439. splits = [float(s) for s in args.split.split(',')]
  440. elif args.split.find('/') != -1:
  441. splits = [float(s) for s in args.split.split('/')]
  442. else:
  443. splits = [float(args.split)]
  444. split_total = sum(splits)
  445. if split_total < 1.:
  446. splits.append(1 - split_total)
  447. while len(splits) < 3:
  448. splits.append(0.)
  449. splits = splits[:3]
  450. if args.valid_data is not None:
  451. splits[1] = 0.
  452. if args.test_data is not None:
  453. splits[2] = 0.
  454. final_sum = sum(splits)
  455. return [s / final_sum for s in splits]
  456. def configure_data():
  457. """add cmdline flags for configuring datasets"""
  458. # These are options that are used by data_utils, but are either
  459. # deprecated or not meant to be exposed to the command line user.
  460. # These options are intneded to be set in code by specific scripts.
  461. defaults = {
  462. 'world_size': 1,
  463. 'rank': -1,
  464. 'persist_state': 0,
  465. 'lazy': False,
  466. 'transpose': False,
  467. 'data_set_type': 'supervised',
  468. 'seq_length': 256,
  469. 'eval_seq_length': 256,
  470. 'samples_per_shard': 100
  471. }
  472. return DataConfig(defaults=defaults)