megatron_lm.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import math
  16. import os
  17. from abc import ABC
  18. from functools import partial
  19. import torch
  20. import torch.nn.functional as F
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
  23. from ..optimizer import AcceleratedOptimizer
  24. from ..scheduler import AcceleratedScheduler
  25. from .imports import is_megatron_lm_available
  26. from .operations import recursively_apply, send_to_device
  27. if is_megatron_lm_available():
  28. from megatron.core import mpu, tensor_parallel
  29. from megatron.core.distributed import DistributedDataParallel as LocalDDP
  30. from megatron.core.distributed import finalize_model_grads
  31. from megatron.core.enums import ModelType
  32. from megatron.core.num_microbatches_calculator import get_num_microbatches
  33. from megatron.core.optimizer import get_megatron_optimizer
  34. from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
  35. from megatron.core.pipeline_parallel import get_forward_backward_func
  36. from megatron.core.utils import get_model_config
  37. from megatron.inference.text_generation.communication import broadcast_int_list, broadcast_tensor
  38. from megatron.inference.text_generation.generation import (
  39. beam_search_and_return_on_first_stage,
  40. generate_tokens_probs_and_return_on_first_stage,
  41. )
  42. from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
  43. from megatron.legacy.model import BertModel, Float16Module, GPTModel, T5Model
  44. from megatron.legacy.model.classification import Classification
  45. from megatron.training import (
  46. get_args,
  47. get_tensorboard_writer,
  48. get_tokenizer,
  49. print_rank_last,
  50. )
  51. from megatron.training.arguments import (
  52. _add_data_args,
  53. _add_validation_args,
  54. core_transformer_config_from_args,
  55. parse_args,
  56. validate_args,
  57. )
  58. from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
  59. from megatron.training.global_vars import set_global_variables
  60. from megatron.training.initialize import (
  61. _compile_dependencies,
  62. _init_autoresume,
  63. _initialize_distributed,
  64. _set_random_seed,
  65. set_jit_fusion_options,
  66. write_args_to_tensorboard,
  67. )
  68. from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
  69. from megatron.training.training import (
  70. build_train_valid_test_data_iterators,
  71. get_optimizer_param_scheduler,
  72. num_floating_point_operations,
  73. setup_model_and_optimizer,
  74. train_step,
  75. training_log,
  76. )
  77. from megatron.training.utils import (
  78. average_losses_across_data_parallel_group,
  79. calc_params_l2_norm,
  80. get_ltor_masks_and_position_ids,
  81. unwrap_model,
  82. )
  83. # model utilities
  84. def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):
  85. """Build the model."""
  86. args = get_args()
  87. mode = "pre-training" if args.pretraining_flag else "fine-tuning"
  88. if args.rank == 0:
  89. print(f"Building {args.model_type_name} model in the {mode} mode.")
  90. print(
  91. "The Megatron LM model weights are initialized at random in `accelerator.prepare`. "
  92. "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup."
  93. )
  94. config = core_transformer_config_from_args(args)
  95. if args.model_type_name == "bert":
  96. if args.pretraining_flag:
  97. num_tokentypes = 2 if args.bert_binary_head else 0
  98. model = BertModel(
  99. config=config,
  100. num_tokentypes=num_tokentypes,
  101. add_binary_head=args.bert_binary_head,
  102. parallel_output=True,
  103. pre_process=pre_process,
  104. post_process=post_process,
  105. )
  106. else:
  107. model = Classification(
  108. config=config,
  109. num_classes=args.num_labels,
  110. num_tokentypes=2,
  111. pre_process=pre_process,
  112. post_process=post_process,
  113. )
  114. elif args.model_type_name == "gpt":
  115. model = GPTModel(
  116. config=config,
  117. num_tokentypes=0,
  118. parallel_output=True,
  119. pre_process=pre_process,
  120. post_process=post_process,
  121. )
  122. elif args.model_type_name == "t5":
  123. model = T5Model(
  124. config=config,
  125. num_tokentypes=0,
  126. parallel_output=True,
  127. pre_process=pre_process,
  128. post_process=post_process,
  129. add_encoder=add_encoder,
  130. add_decoder=add_decoder,
  131. )
  132. else:
  133. raise ValueError(f"Unsupported model type: {args.model_type_name}")
  134. return model
  135. def prepare_model_optimizer_scheduler(accelerator):
  136. accelerator.print("Preparing model optimizer scheduler")
  137. args = get_args()
  138. if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None:
  139. if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None:
  140. raise ValueError(
  141. "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`."
  142. )
  143. custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function
  144. model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func)
  145. optimizer = prepare_optimizer(accelerator, model)
  146. scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)
  147. else:
  148. model_type = ModelType.encoder_or_decoder
  149. if args.model_type_name == "t5":
  150. model_type = ModelType.encoder_and_decoder
  151. model_provider_func_ = model_provider_func
  152. if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None:
  153. model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function
  154. (model, optimizer, scheduler) = setup_model_and_optimizer(
  155. model_provider_func_,
  156. model_type,
  157. no_wd_decay_cond=args.no_wd_decay_cond,
  158. scale_lr_cond=args.scale_lr_cond,
  159. lr_mult=args.lr_mult,
  160. )
  161. args.model_len = len(model)
  162. return model, optimizer, scheduler
  163. # dataloader utilities
  164. class MegatronLMDummyDataLoader:
  165. """
  166. Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training
  167. Args:
  168. **dataset_kwargs: Megatron data arguments.
  169. """
  170. def __init__(self, **dataset_kwargs):
  171. parser = argparse.ArgumentParser()
  172. parser = _add_data_args(parser)
  173. parser = _add_validation_args(parser)
  174. data_args = parser.parse_known_args()
  175. self.dataset_args = vars(data_args[0])
  176. self.dataset_args.update(dataset_kwargs)
  177. self.dataset_args["megatron_dataset_flag"] = True
  178. def set_megatron_data_args(self):
  179. args = get_args()
  180. for key, value in self.dataset_args.items():
  181. old_value = getattr(args, key, "")
  182. if old_value != value:
  183. print(
  184. f"WARNING: MegatronLMDummyDataLoader overriding arguments for {key}:{old_value} with {key}:{value}"
  185. )
  186. setattr(args, key, value)
  187. def get_train_valid_test_datasets_provider(self, accelerator):
  188. def train_valid_test_datasets_provider(train_val_test_num_samples):
  189. """Build train, valid, and test datasets."""
  190. args = get_args()
  191. dataset_args = {
  192. "data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],
  193. "splits_string": args.split,
  194. "train_valid_test_num_samples": train_val_test_num_samples,
  195. "seed": args.seed,
  196. }
  197. if args.model_type_name == "bert":
  198. dataset_args.update(
  199. {
  200. "max_seq_length": args.seq_length,
  201. "binary_head": args.bert_binary_head,
  202. }
  203. )
  204. elif args.model_type_name == "gpt":
  205. dataset_args.update(
  206. {
  207. "max_seq_length": args.seq_length,
  208. }
  209. )
  210. elif args.model_type_name == "t5":
  211. dataset_args.update(
  212. {
  213. "max_seq_length": args.encoder_seq_length,
  214. "max_seq_length_dec": args.decoder_seq_length,
  215. "dataset_type": "t5",
  216. }
  217. )
  218. else:
  219. raise ValueError(f"Unsupported model type: {args.model_type_name}")
  220. train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
  221. return train_ds, valid_ds, test_ds
  222. if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:
  223. return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
  224. try:
  225. args = get_args()
  226. # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
  227. if args.model_type_name == "bert":
  228. from pretrain_bert import train_valid_test_datasets_provider
  229. train_valid_test_datasets_provider.is_distributed = True
  230. return train_valid_test_datasets_provider
  231. elif args.model_type_name == "gpt":
  232. from pretrain_gpt import train_valid_test_datasets_provider
  233. train_valid_test_datasets_provider.is_distributed = True
  234. return train_valid_test_datasets_provider
  235. elif args.model_type_name == "t5":
  236. from pretrain_t5 import train_valid_test_datasets_provider
  237. train_valid_test_datasets_provider.is_distributed = True
  238. return train_valid_test_datasets_provider
  239. except ImportError:
  240. pass
  241. return train_valid_test_datasets_provider
  242. def build_train_valid_test_data_iterators(self, accelerator):
  243. args = get_args()
  244. train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator)
  245. if args.virtual_pipeline_model_parallel_size is not None:
  246. train_data_iterator = []
  247. valid_data_iterator = []
  248. test_data_iterator = []
  249. for i in range(getattr(args, "model_len", 0)):
  250. mpu.set_virtual_pipeline_model_parallel_rank(i)
  251. iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
  252. train_data_iterator.append(iterators[0])
  253. valid_data_iterator.append(iterators[1])
  254. test_data_iterator.append(iterators[2])
  255. else:
  256. train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
  257. train_valid_test_dataset_provider
  258. )
  259. return train_data_iterator, valid_data_iterator, test_data_iterator
  260. def _handle_megatron_data_iterator(accelerator, data_iterator):
  261. class DummyMegatronDataloader:
  262. def __iter__(self):
  263. return self
  264. def __next__(self):
  265. return {}
  266. is_data_iterator_empty = data_iterator is None
  267. is_src_data_iterator_empty = torch.tensor(is_data_iterator_empty, dtype=torch.bool, device=accelerator.device)
  268. torch.distributed.broadcast(
  269. is_src_data_iterator_empty, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
  270. )
  271. if not is_src_data_iterator_empty and is_data_iterator_empty:
  272. return DummyMegatronDataloader()
  273. return data_iterator
  274. def prepare_data_loader(accelerator, dataloader):
  275. accelerator.print("Preparing dataloader")
  276. args = get_args()
  277. if not args.megatron_dataset_flag:
  278. from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader
  279. micro_batch_size = args.micro_batch_size * args.num_micro_batches
  280. kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
  281. if kwargs["batch_size"] is None:
  282. if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
  283. kwargs["sampler"].batch_size = micro_batch_size
  284. else:
  285. del kwargs["sampler"]
  286. del kwargs["shuffle"]
  287. del kwargs["batch_size"]
  288. kwargs["batch_sampler"].batch_size = micro_batch_size
  289. else:
  290. del kwargs["batch_sampler"]
  291. kwargs["batch_size"] = micro_batch_size
  292. dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
  293. # split_batches:
  294. # Megatron only needs to fetch different data between different dp groups,
  295. # and does not need to split the data within the dp group.
  296. return prepare_data_loader(
  297. dataloader,
  298. accelerator.device,
  299. num_processes=mpu.get_data_parallel_world_size(),
  300. process_index=mpu.get_data_parallel_rank(),
  301. split_batches=False,
  302. put_on_device=True,
  303. rng_types=accelerator.rng_types.copy(),
  304. dispatch_batches=accelerator.dispatch_batches,
  305. )
  306. else:
  307. if args.consumed_samples is not None:
  308. (
  309. args.consumed_train_samples,
  310. args.consumed_valid_samples,
  311. args.consumed_test_samples,
  312. ) = args.consumed_samples
  313. else:
  314. args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
  315. args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
  316. # In order to be compatible with data in transform format,
  317. # it needs to increase the size of mbs first,
  318. # and then split the large batch data into some mbs.
  319. (
  320. train_data_iterator,
  321. valid_data_iterator,
  322. test_data_iterator,
  323. ) = dataloader.build_train_valid_test_data_iterators(accelerator)
  324. args.micro_batch_size = args.micro_batch_size // args.num_micro_batches
  325. train_data_iterator = _handle_megatron_data_iterator(
  326. accelerator=accelerator, data_iterator=train_data_iterator
  327. )
  328. valid_data_iterator = _handle_megatron_data_iterator(
  329. accelerator=accelerator, data_iterator=valid_data_iterator
  330. )
  331. test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)
  332. return train_data_iterator, valid_data_iterator, test_data_iterator
  333. # optimizer utilities
  334. class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
  335. def __init__(self, optimizer):
  336. super().__init__(optimizer, device_placement=False, scaler=None)
  337. def zero_grad(self, set_to_none=None):
  338. pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
  339. def step(self):
  340. pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
  341. @property
  342. def step_was_skipped(self):
  343. """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
  344. return self.optimizer.skipped_iter
  345. def prepare_optimizer(accelerator, model):
  346. accelerator.print("Preparing optimizer")
  347. args = get_args()
  348. return get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)
  349. # scheduler utilities
  350. class MegatronLMDummyScheduler:
  351. """
  352. Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
  353. loop when scheduler config is specified in the deepspeed config file.
  354. Args:
  355. optimizer (`torch.optim.optimizer.Optimizer`):
  356. The optimizer to wrap.
  357. total_num_steps (int):
  358. Total number of steps.
  359. warmup_num_steps (int):
  360. Number of steps for warmup.
  361. **kwargs (additional keyword arguments, *optional*):
  362. Other arguments.
  363. """
  364. def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
  365. self.optimizer = optimizer
  366. self.total_num_steps = total_num_steps
  367. self.warmup_num_steps = warmup_num_steps
  368. self.kwargs = kwargs
  369. class MegatronLMSchedulerWrapper(AcceleratedScheduler):
  370. def __init__(self, scheduler, optimizers):
  371. super().__init__(scheduler, optimizers)
  372. def step(self, *args, **kwargs):
  373. return # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
  374. def prepare_scheduler(accelerator, optimizer, scheduler):
  375. accelerator.print("Preparing scheduler")
  376. scheduler = get_optimizer_param_scheduler(optimizer)
  377. return scheduler
  378. class AbstractTrainStep(ABC):
  379. """Abstract class for batching, forward pass and loss handler."""
  380. def __init__(self, name):
  381. super().__init__()
  382. self.name = name
  383. def get_batch_func(self, accelerator, megatron_dataset_flag):
  384. pass
  385. def get_forward_step_func(self):
  386. pass
  387. def get_loss_func(self, accelerator):
  388. pass
  389. class BertTrainStep(AbstractTrainStep):
  390. """
  391. Bert train step class.
  392. Args:
  393. args (`argparse.Namespace`): Megatron-LM arguments.
  394. """
  395. def __init__(self, accelerator, args):
  396. super().__init__("BertTrainStep")
  397. self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
  398. self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels)
  399. self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)
  400. if not args.model_return_dict:
  401. self.model_output_class = None
  402. else:
  403. from transformers.modeling_outputs import SequenceClassifierOutput
  404. self.model_output_class = SequenceClassifierOutput
  405. def get_batch_func(self, accelerator, megatron_dataset_flag):
  406. def get_batch_megatron(data_iterator):
  407. """Build the batch."""
  408. # Items and their type.
  409. keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
  410. datatype = torch.int64
  411. # Broadcast data.
  412. if data_iterator is not None:
  413. data = next(data_iterator)
  414. else:
  415. data = None
  416. data_b = tensor_parallel.broadcast_data(keys, data, datatype)
  417. # Unpack.
  418. tokens = data_b["text"].long()
  419. types = data_b["types"].long()
  420. sentence_order = data_b["is_random"].long()
  421. loss_mask = data_b["loss_mask"].float()
  422. lm_labels = data_b["labels"].long()
  423. padding_mask = data_b["padding_mask"].long()
  424. return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
  425. def get_batch_transformer(data_iterator):
  426. """Build the batch."""
  427. data = next(data_iterator)
  428. data = send_to_device(data, torch.cuda.current_device())
  429. # Unpack.
  430. tokens = data["input_ids"].long()
  431. padding_mask = data["attention_mask"].long()
  432. if "token_type_ids" in data:
  433. types = data["token_type_ids"].long()
  434. else:
  435. types = None
  436. if "labels" in data:
  437. lm_labels = data["labels"].long()
  438. loss_mask = (data["labels"] != -100).to(torch.float)
  439. else:
  440. lm_labels = None
  441. loss_mask = None
  442. if "next_sentence_label" in data:
  443. sentence_order = data["next_sentence_label"].long()
  444. else:
  445. sentence_order = None
  446. return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
  447. if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
  448. return accelerator.state.megatron_lm_plugin.custom_get_batch_function
  449. if megatron_dataset_flag:
  450. try:
  451. # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
  452. from pretrain_bert import get_batch
  453. return get_batch
  454. except ImportError:
  455. pass
  456. return get_batch_megatron
  457. else:
  458. return get_batch_transformer
  459. def get_loss_func(self, accelerator, pretraining_flag, num_labels):
  460. def loss_func_pretrain(loss_mask, sentence_order, output_tensor):
  461. lm_loss_, sop_logits = output_tensor
  462. lm_loss_ = lm_loss_.float()
  463. loss_mask = loss_mask.float()
  464. lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
  465. if sop_logits is not None:
  466. sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
  467. sop_loss = sop_loss.float()
  468. loss = lm_loss + sop_loss
  469. averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
  470. return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]}
  471. else:
  472. loss = lm_loss
  473. averaged_losses = average_losses_across_data_parallel_group([lm_loss])
  474. return loss, {"lm loss": averaged_losses[0]}
  475. def loss_func_finetune(labels, logits):
  476. if num_labels == 1:
  477. # We are doing regression
  478. loss_fct = MSELoss()
  479. loss = loss_fct(logits.view(-1), labels.view(-1))
  480. elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
  481. loss_fct = CrossEntropyLoss()
  482. loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
  483. else:
  484. loss_fct = BCEWithLogitsLoss()
  485. loss = loss_fct(logits, labels)
  486. averaged_losses = average_losses_across_data_parallel_group([loss])
  487. return loss, {"loss": averaged_losses[0]}
  488. if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
  489. return accelerator.state.megatron_lm_plugin.custom_loss_function
  490. if pretraining_flag:
  491. return loss_func_pretrain
  492. else:
  493. return loss_func_finetune
  494. def get_forward_step_func(self, pretraining_flag, bert_binary_head):
  495. def forward_step(data_iterator, model):
  496. """Forward step."""
  497. tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator)
  498. if not bert_binary_head:
  499. types = None
  500. # Forward pass through the model.
  501. if pretraining_flag:
  502. output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels)
  503. return output_tensor, partial(self.loss_func, loss_mask, sentence_order)
  504. else:
  505. logits = model(tokens, padding_mask, tokentype_ids=types)
  506. return logits, partial(self.loss_func, labels)
  507. return forward_step
  508. class GPTTrainStep(AbstractTrainStep):
  509. """
  510. GPT train step class.
  511. Args:
  512. args (`argparse.Namespace`): Megatron-LM arguments.
  513. """
  514. def __init__(self, accelerator, args):
  515. super().__init__("GPTTrainStep")
  516. self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
  517. self.loss_func = self.get_loss_func(accelerator)
  518. self.forward_step = self.get_forward_step_func()
  519. self.eod_token = args.padded_vocab_size - 1
  520. if args.vocab_file is not None:
  521. tokenizer = get_tokenizer()
  522. self.eod_token = tokenizer.eod
  523. self.reset_position_ids = args.reset_position_ids
  524. self.reset_attention_mask = args.reset_attention_mask
  525. self.eod_mask_loss = args.eod_mask_loss
  526. if not args.model_return_dict:
  527. self.model_output_class = None
  528. else:
  529. from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
  530. self.model_output_class = CausalLMOutputWithCrossAttentions
  531. def get_batch_func(self, accelerator, megatron_dataset_flag):
  532. def get_batch_megatron(data_iterator):
  533. """Generate a batch"""
  534. # Items and their type.
  535. keys = ["text"]
  536. datatype = torch.int64
  537. # Broadcast data.
  538. if data_iterator is not None:
  539. data = next(data_iterator)
  540. else:
  541. data = None
  542. data_b = tensor_parallel.broadcast_data(keys, data, datatype)
  543. # Unpack.
  544. tokens_ = data_b["text"].long()
  545. labels = tokens_[:, 1:].contiguous()
  546. tokens = tokens_[:, :-1].contiguous()
  547. # Get the masks and position ids.
  548. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
  549. tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
  550. )
  551. return tokens, labels, loss_mask, attention_mask, position_ids
  552. def get_batch_transformer(data_iterator):
  553. data = next(data_iterator)
  554. data = {"input_ids": data["input_ids"]}
  555. data = send_to_device(data, torch.cuda.current_device())
  556. tokens_ = data["input_ids"].long()
  557. padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token
  558. tokens_ = torch.concat([tokens_, padding], dim=1)
  559. labels = tokens_[:, 1:].contiguous()
  560. tokens = tokens_[:, :-1].contiguous()
  561. # Get the masks and position ids.
  562. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
  563. tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True
  564. )
  565. return tokens, labels, loss_mask, attention_mask, position_ids
  566. if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
  567. return accelerator.state.megatron_lm_plugin.custom_get_batch_function
  568. if megatron_dataset_flag:
  569. try:
  570. # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
  571. from pretrain_gpt import get_batch
  572. return get_batch
  573. except ImportError:
  574. pass
  575. return get_batch_megatron
  576. else:
  577. return get_batch_transformer
  578. def get_loss_func(self, accelerator):
  579. args = get_args()
  580. def loss_func(loss_mask, output_tensor):
  581. if args.return_logits:
  582. losses, logits = output_tensor
  583. else:
  584. losses = output_tensor
  585. losses = losses.float()
  586. loss_mask = loss_mask.view(-1).float()
  587. if args.context_parallel_size > 1:
  588. loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
  589. torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
  590. loss = loss[0] / loss[1]
  591. else:
  592. loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
  593. # Check individual rank losses are not NaN prior to DP all-reduce.
  594. if args.check_for_nan_in_loss_and_grad:
  595. global_rank = torch.distributed.get_rank()
  596. assert not loss.isnan(), (
  597. f"Rank {global_rank}: found NaN in local forward loss calculation. "
  598. f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
  599. )
  600. # Reduce loss for logging.
  601. averaged_loss = average_losses_across_data_parallel_group([loss])
  602. output_dict = {"lm loss": averaged_loss[0]}
  603. if args.return_logits:
  604. output_dict.update({"logits": logits})
  605. return loss, output_dict
  606. if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
  607. return accelerator.state.megatron_lm_plugin.custom_loss_function
  608. return loss_func
  def get_forward_step_func(self):
      """
      Return the GPT forward-step callable expected by Megatron's schedules.

      The returned ``forward_step(data_iterator, model)`` pulls one batch via ``self.get_batch``, runs the model
      with the labels, and returns the output together with ``self.loss_func`` pre-bound to the batch's loss mask.
      """

      def forward_step(data_iterator, model):
          """Forward step."""
          batch = self.get_batch(data_iterator)
          tokens, labels, loss_mask, attention_mask, position_ids = batch
          output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
          bound_loss = partial(self.loss_func, loss_mask)
          return output_tensor, bound_loss

      return forward_step
  class T5TrainStep(AbstractTrainStep):
      """
      T5 train step class.

      Builds the batch, loss and forward-step callables used to drive an encoder-decoder (T5) model.

      Args:
          accelerator (`~accelerate.Accelerator`): accelerator whose ``state.megatron_lm_plugin`` may provide a
              custom batch function and/or a custom loss function.
          args (`argparse.Namespace`): Megatron-LM arguments.
      """

      def __init__(self, accelerator, args):
          super().__init__("T5TrainStep")
          self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
          self.loss_func = self.get_loss_func(accelerator)
          self.forward_step = self.get_forward_step_func()
          if args.model_return_dict:
              # Imported only on this branch, so transformers is not touched when dict outputs are off.
              from transformers.modeling_outputs import Seq2SeqLMOutput

              self.model_output_class = Seq2SeqLMOutput
          else:
              self.model_output_class = None

      @staticmethod
      def attn_mask_postprocess(attention_mask):
          """
          Expand a 2D ``[b, s]`` mask into a 3D ``[b, s, s]`` boolean mask.

          Entry ``[i, q, k]`` is True when ``mask[i, q] * mask[i, k] < 0.5``; for a 0/1 mask this marks pairs in
          which either position is masked out.
          """
          row = attention_mask.unsqueeze(1)  # [b, 1, s]
          col = attention_mask.unsqueeze(2)  # [b, s, 1]
          return (row * col) < 0.5

      @staticmethod
      def get_decoder_mask(seq_length, device):
          """Return a ``[1, seq_length, seq_length]`` boolean mask that is True strictly above the diagonal."""
          lower_triangle = torch.tril(torch.ones((1, seq_length, seq_length), device=device))
          return lower_triangle < 0.5

      @staticmethod
      def get_enc_dec_mask(attention_mask, dec_seq_length, device):
          """
          Build a ``[b, dec_seq_length, s]`` boolean cross-attention mask from a 2D ``[b, s]`` encoder mask.

          Every decoder row repeats the encoder mask; entries are True where the encoder mask value is below 0.5.
          """
          batch_size, _ = attention_mask.shape
          decoder_ones = torch.ones((batch_size, dec_seq_length, 1), device=device)  # [b, t, 1]
          return (decoder_ones * attention_mask.unsqueeze(1)) < 0.5

      def get_batch_func(self, accelerator, megatron_dataset_flag):
          """
          Select the batch-building function.

          Priority: the plugin's ``custom_get_batch_function``; then, for Megatron datasets, ``pretrain_t5.get_batch``
          when importable, falling back to a local Megatron-format builder; otherwise a builder for
          transformers-style dict batches.
          """

          def get_batch_megatron(data_iterator):
              """Build a batch from a Megatron-format dataset, broadcasting it across the tensor-parallel group."""
              keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"]
              datatype = torch.int64
              data = next(data_iterator) if data_iterator is not None else None
              data_b = tensor_parallel.broadcast_data(keys, data, datatype)

              tokens_enc = data_b["text_enc"].long()
              tokens_dec = data_b["text_dec"].long()
              labels = data_b["labels"].long()
              loss_mask = data_b["loss_mask"].float()
              enc_mask, dec_mask, enc_dec_mask = (data_b[k] < 0.5 for k in ("enc_mask", "dec_mask", "enc_dec_mask"))
              return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask

          def get_batch_transformer(data_iterator):
              """Build a batch from a transformers-style dict (``input_ids``, ``labels``, ``attention_mask``, ...)."""
              batch = send_to_device(next(data_iterator), torch.cuda.current_device())

              tokens_enc = batch["input_ids"].long()
              labels = batch["labels"].long()
              # -100 is the ignore index; only other label positions contribute to the loss.
              loss_mask = (labels != -100).to(torch.float)
              if "decoder_input_ids" in batch:
                  tokens_dec = batch["decoder_input_ids"].long()
              else:
                  # Shift labels right by one, start with id 0, and replace the -100 ignore index with 0.
                  tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long)
                  tokens_dec[..., 1:] = labels[..., :-1].clone()
                  tokens_dec[..., 0] = 0
                  tokens_dec.masked_fill_(tokens_dec == -100, 0)

              encoder_attention = batch["attention_mask"].long()
              dec_len = tokens_dec.shape[1]
              enc_mask = T5TrainStep.attn_mask_postprocess(encoder_attention)
              dec_mask = T5TrainStep.get_decoder_mask(dec_len, tokens_dec.device)
              enc_dec_mask = T5TrainStep.get_enc_dec_mask(encoder_attention, dec_len, tokens_dec.device)
              return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask

          custom_get_batch = accelerator.state.megatron_lm_plugin.custom_get_batch_function
          if custom_get_batch is not None:
              return custom_get_batch
          if not megatron_dataset_flag:
              return get_batch_transformer
          try:
              # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
              from pretrain_t5 import get_batch
          except ImportError:
              return get_batch_megatron
          return get_batch

      def get_loss_func(self, accelerator):
          """
          Return the loss callable: the plugin's ``custom_loss_function`` if set, otherwise the mask-weighted mean
          of the per-token LM losses (plus a data-parallel averaged copy under ``"lm loss"`` for logging).
          """

          def loss_func(loss_mask, output_tensor):
              per_token = output_tensor.float().view(-1)
              lm_loss = torch.sum(per_token * loss_mask.reshape(-1)) / loss_mask.sum()
              averaged_losses = average_losses_across_data_parallel_group([lm_loss])
              return lm_loss, {"lm loss": averaged_losses[0]}

          custom_loss = accelerator.state.megatron_lm_plugin.custom_loss_function
          return loss_func if custom_loss is None else custom_loss

      def get_forward_step_func(self):
          """Return ``forward_step(data_iterator, model)``, which runs the model on one batch and binds the loss."""

          def forward_step(data_iterator, model):
              """Forward step."""
              batch = self.get_batch(data_iterator)
              tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = batch
              output_tensor = model(
                  tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels
              )
              return output_tensor, partial(self.loss_func, loss_mask)

          return forward_step
  def finish_mpu_init():
      """Complete Megatron's parallel setup: initialize torch.distributed, then seed the random number generators."""
      args = get_args()
      _initialize_distributed()

      # Seeding for reproducibility; only rank 0 announces it.
      seed = args.seed
      if args.rank == 0:
          print(f"> setting random seeds to {seed} ...")
      _set_random_seed(seed, args.data_parallel_random_init)
  # initialize megatron setup
  def initialize(accelerator, extra_args_provider=None, args_defaults=None):
      """
      Initialize Megatron-LM.

      Parses and validates the Megatron arguments, applies ``args_defaults``, sets Megatron's global state,
      finishes the distributed setup, and fills in a few derived arguments (padded vocab size, BERT binary head,
      starting iteration).

      Args:
          accelerator (`~accelerate.Accelerator`): used for rank-aware printing.
          extra_args_provider (callable, optional): forwarded to Megatron's ``parse_args``.
          args_defaults (dict, optional): ``{name: value}`` pairs written onto the parsed arguments (a warning is
              printed on rank 0 when the argument already had a value). Defaults to no overrides.
      """
      # `None` instead of a mutable `{}` default, so no dict object is shared across calls.
      if args_defaults is None:
          args_defaults = {}

      accelerator.print("Initializing Megatron-LM")
      assert torch.cuda.is_available(), "Megatron requires CUDA."

      # Parse arguments
      args = parse_args(extra_args_provider, ignore_unknown_args=True)

      # Set defaults
      for key, value in args_defaults.items():
          if getattr(args, key, None) is not None:
              if args.rank == 0:
                  print(
                      f"WARNING: overriding default arguments for {key}:{getattr(args, key)} with {key}:{value}",
                      flush=True,
                  )
          setattr(args, key, value)

      if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
          assert args.load is not None, "--use-checkpoints-args requires --load argument"
          load_args_from_checkpoint(args)

      validate_args(args)

      # set global args, build tokenizer, and set adlr-autoresume,
      # tensorboard-writer, and timers.
      set_global_variables(args)

      # Megatron's MPU is the master. Complete initialization right away.
      finish_mpu_init()

      # Autoresume.
      _init_autoresume()

      # Compile dependencies.
      _compile_dependencies()

      # Set pytorch JIT layer fusion options and warmup JIT functions.
      set_jit_fusion_options()
      args = get_args()
      if getattr(args, "padded_vocab_size", None) is None:
          args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
      args.bert_binary_head = args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2
      args.iteration = 0
  class MegatronEngine(torch.nn.Module):
      """
      Megatron-LM model wrapper.

      In training mode, calling the engine runs a full Megatron train step: forward/backward over all
      micro-batches, then the optimizer and scheduler update. In eval mode it runs a forward-only pass.
      See `forward`.

      Args:
          accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.
          model: Sequence of Megatron-LM model chunks; ``model[0]`` is kept as ``base_model``.
          optimizer: Megatron-LM optimizer
          scheduler: Megatron-LM lr scheduler
      """

      def __init__(self, accelerator, model, optimizer, scheduler):
          super().__init__()
          self.module = model
          self.base_model = model[0]
          self.optimizer = optimizer
          self.scheduler = scheduler
          args = get_args()
          plugin = accelerator.state.megatron_lm_plugin
          if plugin.custom_train_step_class is not None:
              self.train_step_handler = plugin.custom_train_step_class(args, **plugin.custom_train_step_kwargs)
          elif args.model_type_name == "bert":
              self.train_step_handler = BertTrainStep(accelerator, args)
          elif args.model_type_name == "gpt":
              self.train_step_handler = GPTTrainStep(accelerator, args)
          elif args.model_type_name == "t5":
              self.train_step_handler = T5TrainStep(accelerator, args)
          else:
              raise ValueError(f"Unsupported model type: {args.model_type_name}")
          self.optimizer.skipped_iter = False

          # Tracking loss.
          self.total_loss_dict = {}
          self.eval_total_loss_dict = {}
          self.iteration = 0
          self.report_memory_flag = True
          self.num_floating_point_operations_so_far = 0
          # Built lazily on the first call to `train()` / `eval()`.
          self.module_config = None
          if args.tensorboard_dir is not None:
              write_args_to_tensorboard()

      def get_module_config(self):
          """
          Build the Megatron model config used by the train step.

          The config carries:
          - loss scaling;
          - no-sync / grad-sync hooks, when the model is wrapped in Megatron's local DDP with
            ``overlap_grad_reduce``;
          - per-chunk param-sync hooks;
          - the gradient finalization function.
          """
          args = get_args()
          config = get_model_config(self.module[0])
          # Setup some training config params
          config.grad_scale_func = self.optimizer.scale_loss
          if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:
              assert config.no_sync_func is None, (
                  "When overlap_grad_reduce is True, config.no_sync_func must be None; "
                  "a custom no_sync_func is not supported when overlapping grad-reduce"
              )
              config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]
              if len(self.module) == 1:
                  config.no_sync_func = config.no_sync_func[0]
              if args.delay_grad_reduce:
                  config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]
                  if len(self.module) == 1:
                      config.grad_sync_func = config.grad_sync_func[0]
          if args.overlap_param_gather and args.delay_param_gather:

              def _param_sync_for(model_index):
                  # Bind `model_index` per chunk. A bare lambda inside the comprehension would close over the
                  # loop variable, and every hook would sync the LAST chunk (late-binding closure).
                  return lambda x: self.optimizer.finish_param_sync(model_index, x)

              config.param_sync_func = [_param_sync_for(model_index) for model_index in range(len(self.module))]
              if len(self.module) == 1:
                  config.param_sync_func = config.param_sync_func[0]
          config.finalize_model_grads_func = finalize_model_grads
          return config

      def train(self):
          """Put every model chunk in training mode, build the config once, and flush pending eval logs."""
          for model_module in self.module:
              model_module.train()
          if self.module_config is None:
              self.module_config = self.get_module_config()
          self.log_eval_results()

      def eval(self):
          """Put every model chunk in evaluation mode and build the config once."""
          for model_module in self.module:
              model_module.eval()
          if self.module_config is None:
              self.module_config = self.get_module_config()

      def get_batch_data_iterator(self, batch_data):
          """
          Split ``batch_data`` into ``num_micro_batches`` slices of ``micro_batch_size`` rows and wrap them.

          Returns one iterator per model chunk when there are several chunks, otherwise a single iterator.
          When ``batch_data`` is empty, ``None`` stands in for each iterator.
          """
          args = get_args()
          has_data = len(batch_data) > 0
          data_chunks = []
          if has_data:
              if args.num_micro_batches > 1:
                  mbs = args.micro_batch_size
                  data_chunks = [
                      {k: v[i * mbs : (i + 1) * mbs] for k, v in batch_data.items()}
                      for i in range(args.num_micro_batches)
                  ]
              else:
                  data_chunks = [batch_data]

          if len(self.module) > 1:
              if has_data:
                  return [iter(data_chunks) for _ in range(len(self.module))]
              return [None] * len(self.module)
          return iter(data_chunks) if has_data else None

      def train_step(self, **batch_data):
          """
          Training step for Megatron-LM

          Args:
              batch_data (:obj:`dict`): The batch data to train on.

          Returns:
              ``(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)`` as returned by Megatron's ``train_step``.
          """
          batch_data_iterator = self.get_batch_data_iterator(batch_data)
          loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
              forward_step_func=self.train_step_handler.forward_step,
              data_iterator=batch_data_iterator,
              model=self.module,
              optimizer=self.optimizer,
              opt_param_scheduler=self.scheduler,
              config=self.module_config,
          )
          self.optimizer.skipped_iter = skipped_iter == 1
          return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad

      def eval_step(self, **batch_data):
          """
          Evaluation step for Megatron-LM

          Args:
              batch_data (:obj:`dict`): The batch data to evaluate on.

          Returns:
              On the last pipeline stage, a dict mapping each loss key to its mean over micro-batches (scalar
              entries) or the concatenation of the per-micro-batch values (non-scalar entries); ``{}`` elsewhere.
          """
          args = get_args()
          batch_data_iterator = self.get_batch_data_iterator(batch_data)
          forward_backward_func = get_forward_backward_func()
          loss_dicts = forward_backward_func(
              forward_step_func=self.train_step_handler.forward_step,
              data_iterator=batch_data_iterator,
              model=self.module,
              num_microbatches=get_num_microbatches(),
              seq_length=args.seq_length,
              micro_batch_size=args.micro_batch_size,
              forward_only=True,
          )
          # Empty unused memory
          if args.empty_unused_memory_level >= 1:
              torch.cuda.empty_cache()

          args.consumed_valid_samples += (
              mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
          )

          if mpu.is_pipeline_last_stage(ignore_virtual=True):
              # Average loss across microbatches.
              loss_reduced = {}
              for key in loss_dicts[0]:
                  losses_reduced_for_key = [x[key] for x in loss_dicts]
                  if len(losses_reduced_for_key[0].shape) == 0:
                      loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
                  else:
                      loss_reduced[key] = torch.concat(losses_reduced_for_key)
              return loss_reduced
          return {}

      def forward(self, **batch_data):
          # During training, we use train_step()
          # model(**batch_data) performs following operations by delegating it to `self.train_step`:
          # 1. Prepare **batch_data for Tensor, Pipeline and Model Parallelism
          # 2. Set grad to zero.
          # 3. forward pass and backward pass using Pipeline Parallelism
          # 4. Empty unused memory.
          # 5. Reduce gradients.
          # 6. Update parameters.
          # 7. Gather params when using Distributed Optimizer (Data Parallelism).
          # 8. Update learning rate if scheduler is specified.
          # 9. Empty unused memory.
          # 10. Average loss across microbatches and across DP ranks.
          #
          # During evaluation, we use eval_step()
          #
          # Returns the sum of the scalar losses, wrapped in `model_output_class` (with any "logits") when the
          # train-step handler defines one.
          args = get_args()
          if self.module[0].training:
              loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
              self.iteration += 1
              batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
              args.consumed_train_samples += batch_size
              self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
              if args.tensorboard_dir is not None:
                  # Logging.
                  loss_scale = self.optimizer.get_loss_scale().item()
                  params_norm = None
                  if args.log_params_norm:
                      # The model chunks live in `self.module`; there is no `self.model` attribute.
                      params_norm = calc_params_l2_norm(self.module)
                  self.report_memory_flag = training_log(
                      loss_dict,
                      self.total_loss_dict,
                      self.optimizer.param_groups[0]["lr"],
                      self.iteration,
                      loss_scale,
                      self.report_memory_flag,
                      skipped_iter,
                      grad_norm,
                      params_norm,
                      num_zeros_in_grad,
                  )
          else:
              loss_dict = self.eval_step(**batch_data)
              if args.tensorboard_dir is not None:
                  for key, value in loss_dict.items():
                      # Accumulate scalar losses only. Non-scalar entries (e.g. "logits") cannot be reduced with
                      # `.item()` in `log_eval_results` and may change shape from batch to batch.
                      if len(value.shape) != 0:
                          continue
                      self.eval_total_loss_dict[key] = (
                          self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + value
                      )
                      self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get(
                          key + "_num_iters", torch.cuda.FloatTensor([0.0])
                      ) + torch.cuda.FloatTensor([1.0])

          loss = torch.tensor(0.0, device=torch.cuda.current_device())
          for key in loss_dict:
              if len(loss_dict[key].shape) == 0:
                  loss += loss_dict[key]

          logits = loss_dict.get("logits")
          if self.train_step_handler.model_output_class is not None:
              return self.train_step_handler.model_output_class(loss=loss, logits=logits)
          return loss

      def log_eval_results(self):
          """
          Print (and write to TensorBoard) the validation losses averaged since the last call, then reset them.

          No-op when TensorBoard is disabled or no training iteration has run yet.
          """
          args = get_args()
          if args.tensorboard_dir is None or self.iteration == 0:
              return
          writer = get_tensorboard_writer()
          string = f"validation loss at iteration {self.iteration} | "
          for key in self.eval_total_loss_dict:
              if key.endswith("_num_iters"):
                  continue
              value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"]
              string += f"{key} value: {value} | "
              # Perplexity, with the exponent capped at 20.
              ppl = math.exp(min(20, value.item()))
              if args.pretraining_flag:
                  string += f"{key} PPL: {ppl} | "
              if writer:
                  writer.add_scalar(f"{key} validation", value.item(), self.iteration)
                  if args.pretraining_flag:
                      writer.add_scalar(f"{key} validation ppl", ppl, self.iteration)

          length = len(string) + 1
          print_rank_last("-" * length)
          print_rank_last(string)
          print_rank_last("-" * length)
          self.eval_total_loss_dict = {}

      def save_checkpoint(self, output_dir):
          """Flush pending eval logs, then save a Megatron checkpoint (model, optimizer, scheduler) to ``output_dir``."""
          self.log_eval_results()
          args = get_args()
          args.save = output_dir
          torch.distributed.barrier()
          # Resolves to the module-level Megatron `save_checkpoint`, not this method.
          save_checkpoint(
              self.iteration,
              self.module,
              self.optimizer,
              self.scheduler,
              num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,
          )
          torch.distributed.barrier()

      def load_checkpoint(self, input_dir):
          """Load a Megatron checkpoint from ``input_dir`` and restore the iteration and FLOP counters."""
          args = get_args()
          args.load = input_dir
          args.consumed_train_samples = 0
          args.consumed_valid_samples = 0
          torch.distributed.barrier()
          # Resolves to the module-level Megatron `load_checkpoint`, not this method.
          iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)
          torch.distributed.barrier()
          self.iteration = iteration
          self.num_floating_point_operations_so_far = num_floating_point_operations_so_far
          if args.fp16 and self.iteration == 0:
              self.optimizer.reload_model_params()

      def megatron_generate(
          self,
          inputs,
          attention_mask=None,
          max_length=None,
          max_new_tokens=None,
          num_beams=None,
          temperature=None,
          top_k=None,
          top_p=None,
          length_penalty=None,
          **kwargs,
      ):
          """
          Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along
          with sampling. Refer the Megatron-LM repo for more details

          Args:
              inputs (torch.Tensor): input ids
              attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
              max_length (int, optional): max length of the generated sequence. Defaults to None.
                  Either this or max_new_tokens should be provided.
              max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None.
                  Either this or max_length should be provided.
              num_beams (int, optional): number of beams to use for beam search. Defaults to None.
              temperature (float, optional): temperature for sampling. Defaults to 1.0.
              top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0.
              top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0.
              length_penalty (float, optional): length penalty for beam search. Defaults to None.
              kwargs: additional key-value arguments
                  (``top_p_decay``, ``top_p_bound``, ``add_BOS``, ``stop_token``, ``random_seed``).

          Raises:
              NotImplementedError: for non-GPT models.
              ValueError: on unsupported parallel settings or invalid generation arguments (including beam search
                  with a batch size greater than 1).
          """
          # checking if required arguments are passed
          args = get_args()
          if args.model_type_name != "gpt":
              raise NotImplementedError("Generate method is not implemented for this model")

          if args.data_parallel_size > 1:
              raise ValueError("Generate method requires data parallelism to be 1")

          if args.sequence_parallel:
              raise ValueError("Generate method requires sequence parallelism to be False")

          if args.recompute_granularity is not None:
              raise ValueError("Checkpoint activations cannot be set for inference")

          if args.vocab_file is None:
              raise ValueError("Vocab file is required for inference")

          # Prepare inputs
          if max_length is None and max_new_tokens is None:
              raise ValueError("`max_length` or `max_new_tokens` are required for inference")

          if temperature is None:
              temperature = 1.0
          elif not (0.0 < temperature <= 100.0):
              raise ValueError("temperature must be a positive number less than or equal to 100.0")

          if top_k is None:
              top_k = 0
          elif not (0 <= top_k <= 1000):
              raise ValueError("top_k must be a positive number less than or equal to 1000")

          if top_p is None:
              top_p = 0.0
          elif top_p > 0.0 and top_k > 0.0:
              raise ValueError("top_p and top_k sampling cannot be set together")
          else:
              if not (0.0 <= top_p <= 1.0):
                  raise ValueError("top_p must be less than or equal to 1.0")

          top_p_decay = kwargs.get("top_p_decay", 0.0)
          if not (0.0 <= top_p_decay <= 1.0):
              raise ValueError("top_p_decay must be less than or equal to 1.0")

          top_p_bound = kwargs.get("top_p_bound", 0.0)
          if not (0.0 <= top_p_bound <= 1.0):
              raise ValueError("top_p_bound must be less than or equal to 1.0")

          add_BOS = kwargs.get("add_BOS", False)
          if not (isinstance(add_BOS, bool)):
              raise ValueError("add_BOS must be a boolean")

          beam_width = num_beams
          if beam_width is not None:
              if not isinstance(beam_width, int):
                  raise ValueError("beam_width must be an integer")
              if beam_width < 1:
                  raise ValueError("beam_width must be greater than 0")
              if inputs.shape[0] > 1:
                  # Previously this message was *returned* as if it were the generated output; raise instead.
                  raise ValueError("When doing beam_search, batch size must be 1")

          tokenizer = get_tokenizer()

          stop_token = kwargs.get("stop_token", tokenizer.eod)
          if stop_token is not None:
              if not isinstance(stop_token, int):
                  raise ValueError("stop_token must be an integer")

          if length_penalty is None:
              length_penalty = 1.0

          sizes_list = None
          prompts_tokens_tensor = None
          prompts_length_tensor = None
          if torch.distributed.get_rank() == 0:
              # Get the prompts length.
              if attention_mask is None:
                  prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0])
              else:
                  prompts_length_tensor = attention_mask.sum(axis=-1).cuda()

              if max_new_tokens is None:
                  max_new_tokens = max_length - inputs.shape[1]
              if max_new_tokens <= 0:
                  raise ValueError("max_new_tokens must be greater than 0")

              if add_BOS:
                  max_length = max_new_tokens + inputs.shape[1] + 1
                  # making sure that `max_length` is a multiple of 4 to leverage fused kernels
                  max_length = 4 * math.ceil(max_length / 4)
                  max_new_tokens = max_length - (inputs.shape[1] + 1)
                  padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
                  prompts_tokens_tensor = torch.concat(
                      [torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1
                  )
                  # The prepended BOS token is part of every prompt, so it must be counted in the prompt lengths.
                  # Otherwise generation starts one position early and overwrites the last prompt token.
                  prompts_length_tensor = prompts_length_tensor + 1
              else:
                  # making sure that `max_length` is a multiple of 4 to leverage fused kernels
                  max_length = max_new_tokens + inputs.shape[1]
                  max_length = 4 * math.ceil(max_length / 4)
                  max_new_tokens = max_length - inputs.shape[1]
                  padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
                  prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)

              # We need the sizes of these tensors for the broadcast
              sizes_list = [
                  prompts_tokens_tensor.size(0),  # Batch size
                  prompts_tokens_tensor.size(1),  # Sequence length
              ]

          # First, broadcast the sizes.
          sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)

          # Now that we have the sizes, we can broadcast the tokens
          # and length tensors.
          sizes = sizes_tensor.tolist()
          context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
          context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0)

          # Run the inference
          random_seed = kwargs.get("random_seed", 0)
          torch.random.manual_seed(random_seed)
          unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module))
          if beam_width is not None:
              tokens, _ = beam_search_and_return_on_first_stage(
                  unwrapped_model,
                  context_tokens_tensor,
                  context_length_tensor,
                  beam_width,
                  stop_token=stop_token,
                  num_return_gen=1,
                  length_penalty=length_penalty,
              )
          else:
              tokens, _, _ = generate_tokens_probs_and_return_on_first_stage(
                  unwrapped_model,
                  context_tokens_tensor,
                  context_length_tensor,
                  return_output_log_probs=False,
                  top_k=top_k,
                  top_p=top_p,
                  top_p_decay=top_p_decay,
                  top_p_bound=top_p_bound,
                  temperature=temperature,
                  use_eod_token_for_early_termination=True,
              )
          return tokens
  # other utilities
  def avg_losses_across_data_parallel_group(losses):
      """
      Average a list of losses over the data-parallel group.

      Thin public wrapper around Megatron's ``average_losses_across_data_parallel_group``.

      Args:
          losses (List[Tensor]): List of losses to average across data parallel group.

      Returns:
          Whatever Megatron's helper returns for ``losses``.
      """
      averaged = average_losses_across_data_parallel_group(losses)
      return averaged
  def gather_across_data_parallel_groups(tensor):
      """
      All-gather every tensor in a (possibly nested) list/tuple/dict from all data-parallel ranks.

      Each tensor is gathered over the data-parallel group and the pieces are concatenated along dim 0; 0-dim
      tensors are first lifted to shape ``(1,)``. Non-tensor leaves raise (``error_on_other_type=True``).

      Args:
          tensor (nested list/tuple/dictionary of `torch.Tensor`):
              The data to gather across data parallel ranks.
      """

      def _gather_single(local):
          if local.ndim == 0:
              local = local.clone()[None]
          dp_group = mpu.get_data_parallel_group()
          world_size = torch.distributed.get_world_size(group=dp_group)
          gathered = [torch.empty_like(local) for _ in range(world_size)]
          torch.distributed.all_gather(gathered, local, group=dp_group)
          return torch.cat(gathered, dim=0)

      return recursively_apply(_gather_single, tensor, error_on_other_type=True)