# test_sync.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from copy import deepcopy
  15. import torch
  16. import torch.nn.functional as F
  17. from torch.optim import AdamW
  18. from torch.optim.lr_scheduler import LambdaLR
  19. from torch.utils.data import DataLoader
  20. from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
  21. from accelerate.state import GradientState
  22. from accelerate.test_utils import RegressionDataset, RegressionModel
  23. from accelerate.utils import DistributedType, set_seed
  24. def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
  25. for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
  26. if not param.requires_grad:
  27. continue
  28. if not did_step:
  29. # Grads should not be in sync
  30. assert torch.allclose(param.grad, grad_param.grad, **kwargs) is False, (
  31. f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
  32. )
  33. else:
  34. # Grads should be in sync
  35. assert torch.allclose(param.grad, grad_param.grad, **kwargs) is True, (
  36. f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
  37. )
  38. def step_model(model, input, target, accelerator, do_backward=True):
  39. model.train()
  40. output = model(input)
  41. loss = F.mse_loss(output, target.to(output.device))
  42. if not do_backward:
  43. loss /= accelerator.gradient_accumulation_steps
  44. loss.backward()
  45. else:
  46. accelerator.backward(loss)
  47. def get_training_setup(accelerator, sched=False):
  48. "Returns everything needed to perform basic training"
  49. set_seed(42)
  50. model = RegressionModel()
  51. ddp_model = deepcopy(model)
  52. dset = RegressionDataset(length=80)
  53. dataloader = DataLoader(dset, batch_size=16)
  54. model.to(accelerator.device)
  55. if sched:
  56. opt = AdamW(params=model.parameters(), lr=1e-3)
  57. ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
  58. sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
  59. ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
  60. # Make a copy of `model`
  61. if sched:
  62. ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
  63. else:
  64. ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
  65. if sched:
  66. return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)
  67. return model, ddp_model, dataloader
  68. def test_noop_sync(accelerator):
  69. # Test when on a single CPU or GPU that the context manager does nothing
  70. model, ddp_model, dataloader = get_training_setup(accelerator)
  71. # Use a single batch
  72. ddp_input, ddp_target = next(iter(dataloader)).values()
  73. for iteration in range(3):
  74. # Gather the distributed inputs and targs for the base model
  75. input, target = accelerator.gather((ddp_input, ddp_target))
  76. input, target = input.to(accelerator.device), target.to(accelerator.device)
  77. # Perform our initial ground truth step in non "DDP"
  78. step_model(model, input, target, accelerator)
  79. # Do "gradient accumulation" (noop)
  80. if iteration % 2 == 0:
  81. # Accumulate grads locally
  82. with accelerator.no_sync(ddp_model):
  83. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  84. else:
  85. # Sync grads
  86. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  87. # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
  88. check_model_parameters(model, ddp_model, True, iteration)
  89. for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
  90. if not param.requires_grad:
  91. continue
  92. assert torch.allclose(param.grad, ddp_param.grad), (
  93. f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
  94. )
  95. # Shuffle ddp_input on each iteration
  96. torch.manual_seed(1337 + iteration)
  97. ddp_input = ddp_input[torch.randperm(len(ddp_input))]
  98. def test_distributed_sync(accelerator):
  99. # Test on distributed setup that context manager behaves properly
  100. model, ddp_model, dataloader = get_training_setup(accelerator)
  101. # Use a single batch
  102. ddp_input, ddp_target = next(iter(dataloader)).values()
  103. for iteration in range(3):
  104. # Gather the distributed inputs and targs for the base model
  105. input, target = accelerator.gather((ddp_input, ddp_target))
  106. input, target = input.to(accelerator.device), target.to(accelerator.device)
  107. # Perform our initial ground truth step in non "DDP"
  108. step_model(model, input, target, accelerator)
  109. # Do "gradient accumulation" (noop)
  110. if iteration % 2 == 0:
  111. # Accumulate grads locally
  112. with accelerator.no_sync(ddp_model):
  113. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  114. else:
  115. # Sync grads
  116. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  117. # DDP model and model should only be in sync when not (iteration % 2 == 0)
  118. for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
  119. if not param.requires_grad:
  120. continue
  121. if iteration % 2 == 0:
  122. # Grads should not be in sync
  123. assert torch.allclose(param.grad, ddp_param.grad) is False, (
  124. f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
  125. )
  126. else:
  127. # Grads should be in sync
  128. assert torch.allclose(param.grad, ddp_param.grad) is True, (
  129. f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
  130. )
  131. # Shuffle ddp_input on each iteration
  132. torch.manual_seed(1337 + iteration)
  133. ddp_input = ddp_input[torch.randperm(len(ddp_input))]
  134. def test_distributed_sync_multiple_fwd(accelerator):
  135. # Test on distributed setup that context manager behaves properly when used with multiple forwards followed by multiple backwards
  136. model, ddp_model, dataloader = get_training_setup(accelerator)
  137. # Do multiple forwards
  138. losses = []
  139. num_iterations = 3
  140. for iteration in range(num_iterations):
  141. ddp_input, ddp_target = next(iter(dataloader)).values()
  142. # Gather the distributed inputs and targs for the base model
  143. input, target = accelerator.gather((ddp_input, ddp_target))
  144. input, target = input.to(accelerator.device), target.to(accelerator.device)
  145. # Perform our initial ground truth step in non "DDP"
  146. step_model(model, input, target, accelerator)
  147. # Accumulate grads locally
  148. with accelerator.no_sync(ddp_model):
  149. ddp_output = ddp_model(ddp_input)
  150. loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
  151. losses.append(loss)
  152. # Do multiple backwards and sync only at the last backward
  153. for iteration in range(num_iterations):
  154. loss = losses[iteration]
  155. if iteration < num_iterations - 1:
  156. # Accumulate grads locally
  157. accelerator.backward(loss)
  158. # DDP model and model should only be in sync after last backward
  159. for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
  160. if not param.requires_grad:
  161. continue
  162. # Grads should not be in sync
  163. assert torch.allclose(param.grad, ddp_param.grad) is False, (
  164. f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
  165. )
  166. else:
  167. # Sync grads if last backward
  168. with accelerator.trigger_sync_in_backward(ddp_model):
  169. accelerator.backward(loss)
  170. # DDP model and model should only be in sync after last backward
  171. for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
  172. if not param.requires_grad:
  173. continue
  174. # Grads should be in sync
  175. assert torch.allclose(param.grad, ddp_param.grad) is True, (
  176. f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
  177. )
  178. def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
  179. gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
  180. dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
  181. accelerator = Accelerator(
  182. dataloader_config=dataloader_config,
  183. gradient_accumulation_plugin=gradient_accumulation_plugin,
  184. )
  185. # Test that context manager behaves properly
  186. model, ddp_model, dataloader = get_training_setup(accelerator)
  187. for iteration, batch in enumerate(dataloader):
  188. ddp_input, ddp_target = batch.values()
  189. # Gather the distributed inputs and targs for the base model
  190. input, target = accelerator.gather((ddp_input, ddp_target))
  191. input, target = input.to(accelerator.device), target.to(accelerator.device)
  192. # Perform our initial ground truth step in non "DDP"
  193. step_model(model, input, target, accelerator, False)
  194. # Do "gradient accumulation" (noop)
  195. with accelerator.accumulate(ddp_model):
  196. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  197. # DDP model and model should only be in sync when not (iteration % 2 == 0)
  198. for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
  199. if not param.requires_grad:
  200. continue
  201. if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:
  202. # Grads should be in sync
  203. assert torch.allclose(param.grad, ddp_param.grad) is True, (
  204. f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
  205. )
  206. else:
  207. # Grads should not be in sync
  208. assert torch.allclose(param.grad, ddp_param.grad) is False, (
  209. f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
  210. )
  211. # Shuffle ddp_input on each iteration
  212. torch.manual_seed(1337 + iteration)
  213. ddp_input = ddp_input[torch.randperm(len(ddp_input))]
  214. GradientState._reset_state()
  215. def test_gradient_accumulation_with_opt_and_scheduler(
  216. split_batches=False, dispatch_batches=False, sync_each_batch=False
  217. ):
  218. gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
  219. dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
  220. accelerator = Accelerator(
  221. dataloader_config=dataloader_config,
  222. gradient_accumulation_plugin=gradient_accumulation_plugin,
  223. )
  224. # Test that context manager behaves properly
  225. model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
  226. for iteration, batch in enumerate(dataloader):
  227. ddp_input, ddp_target = batch.values()
  228. # Gather the distributed inputs and targs for the base model
  229. input, target = accelerator.gather((ddp_input, ddp_target))
  230. input, target = input.to(accelerator.device), target.to(accelerator.device)
  231. # Perform our initial ground truth step in non "DDP"
  232. model.train()
  233. ddp_model.train()
  234. step_model(model, input, target, accelerator, False)
  235. opt.step()
  236. if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):
  237. if split_batches:
  238. sched.step()
  239. else:
  240. for _ in range(accelerator.num_processes):
  241. sched.step()
  242. # Perform gradient accumulation under wrapper
  243. with accelerator.accumulate(ddp_model):
  244. step_model(ddp_model, ddp_input, ddp_target, accelerator)
  245. ddp_opt.step()
  246. ddp_sched.step()
  247. # Learning rates should be the same
  248. assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"], (
  249. f"Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]['lr']}\nDDP opt: {ddp_opt.param_groups[0]['lr']}\n"
  250. )
  251. did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
  252. if accelerator.num_processes > 1:
  253. check_model_parameters(
  254. model,
  255. ddp_model,
  256. did_step or sync_each_batch, # syncs at each grad_accum interval of if sync_each_batch==True
  257. iteration,
  258. rtol=1e-3, # needs a relative tolerance due to roundoff errors
  259. )
  260. if did_step:
  261. opt.zero_grad() # flush gradients every accum step
  262. ddp_opt.zero_grad()
  263. # Shuffle ddp_input on each iteration
  264. torch.manual_seed(1337 + iteration)
  265. GradientState._reset_state()
  266. def test_dataloader_break():
  267. accelerator = Accelerator()
  268. first_dset = RegressionDataset(length=80)
  269. first_dataloader = DataLoader(first_dset, batch_size=16)
  270. second_dset = RegressionDataset(length=96)
  271. second_dataloader = DataLoader(second_dset, batch_size=16)
  272. first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
  273. assert accelerator.gradient_state.active_dataloader is None
  274. for iteration, _ in enumerate(first_dataloader):
  275. assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
  276. if iteration < len(first_dataloader) - 1:
  277. assert not accelerator.gradient_state.end_of_dataloader
  278. if iteration == 1:
  279. for batch_num, _ in enumerate(second_dataloader):
  280. assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader)
  281. if batch_num < len(second_dataloader) - 1:
  282. assert not accelerator.gradient_state.end_of_dataloader
  283. else:
  284. assert accelerator.gradient_state.end_of_dataloader
  285. else:
  286. assert accelerator.gradient_state.end_of_dataloader
  287. assert accelerator.gradient_state.active_dataloader is None
  288. def main():
  289. accelerator = Accelerator()
  290. state = accelerator.state
  291. if state.local_process_index == 0:
  292. print("**Test `accumulate` gradient accumulation with dataloader break**")
  293. if state.distributed_type != DistributedType.XLA:
  294. test_dataloader_break()
  295. if state.distributed_type == DistributedType.NO:
  296. if state.local_process_index == 0:
  297. print("**Test NOOP `no_sync` context manager**")
  298. test_noop_sync(accelerator)
  299. if state.distributed_type in (
  300. DistributedType.MULTI_GPU,
  301. DistributedType.MULTI_NPU,
  302. DistributedType.MULTI_MLU,
  303. DistributedType.MULTI_SDAA,
  304. DistributedType.MULTI_MUSA,
  305. DistributedType.MULTI_CPU,
  306. DistributedType.MULTI_HPU,
  307. ):
  308. if state.local_process_index == 0:
  309. print("**Test Distributed `no_sync` context manager**")
  310. test_distributed_sync(accelerator)
  311. if state.local_process_index == 0:
  312. print("**Test Distributed `no_sync` context manager with multiple forwards**")
  313. test_distributed_sync_multiple_fwd(accelerator)
  314. if state.distributed_type in (
  315. DistributedType.MULTI_GPU,
  316. DistributedType.MULTI_NPU,
  317. DistributedType.MULTI_MLU,
  318. DistributedType.MULTI_SDAA,
  319. DistributedType.MULTI_MUSA,
  320. DistributedType.MULTI_HPU,
  321. ):
  322. for split_batch in [True, False]:
  323. for dispatch_batches in [True, False]:
  324. for sync_each_batch in [True, False]:
  325. if state.local_process_index == 0:
  326. print(
  327. "**Test `accumulate` gradient accumulation, ",
  328. f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
  329. )
  330. test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)
  331. # Currently will break on torch 2.0 +, need to investigate why
  332. if state.local_process_index == 0:
  333. print(
  334. "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
  335. "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
  336. )
  337. test_gradient_accumulation_with_opt_and_scheduler()
  338. if state.distributed_type in (
  339. DistributedType.MULTI_GPU,
  340. DistributedType.MULTI_NPU,
  341. DistributedType.MULTI_MLU,
  342. DistributedType.MULTI_SDAA,
  343. DistributedType.MULTI_MUSA,
  344. DistributedType.MULTI_HPU,
  345. ):
  346. for split_batch in [True, False]:
  347. for dispatch_batches in [True, False]:
  348. for sync_each_batch in [True, False]:
  349. if not split_batch and not dispatch_batches and not sync_each_batch:
  350. continue
  351. if state.local_process_index == 0:
  352. print(
  353. "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
  354. f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
  355. )
  356. test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
  357. state.destroy_process_group()
  358. def _mp_fn(index):
  359. # For xla_spawn (TPUs)
  360. main()
  361. if __name__ == "__main__":
  362. main()