test_script.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import contextlib
  16. import io
  17. import math
  18. import time
  19. from copy import deepcopy
  20. from pathlib import Path
  21. import numpy as np
  22. import torch
  23. from torch.utils.data import DataLoader, Dataset
  24. from accelerate import Accelerator
  25. from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
  26. from accelerate.state import AcceleratorState
  27. from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
  28. from accelerate.utils import (
  29. DataLoaderConfiguration,
  30. DistributedType,
  31. gather,
  32. gather_object,
  33. is_bf16_available,
  34. is_cuda_available,
  35. is_datasets_available,
  36. is_fp16_available,
  37. is_hpu_available,
  38. is_ipex_available,
  39. is_mps_available,
  40. is_pytest_available,
  41. set_seed,
  42. synchronize_rng_states,
  43. )
  44. if is_hpu_available():
  45. ATOL = 1e-3
  46. RTOL = 1e-3
  47. else:
  48. ATOL = 1e-6
  49. RTOL = 1e-6
  50. def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):
  51. "Creates a dataloader that can also use the `SeedableRandomSampler`"
  52. if use_seedable_sampler:
  53. # The SeedableRandomSampler is needed during distributed setups
  54. # for full reproducibility across processes with the `DataLoader`
  55. sampler = SeedableRandomSampler(
  56. generator=generator,
  57. data_source=train_set,
  58. num_samples=len(train_set),
  59. )
  60. return DataLoader(train_set, batch_size=batch_size, sampler=sampler)
  61. else:
  62. return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
  63. def print_main(state):
  64. print(f"Printing from the main process {state.process_index}")
  65. def print_local_main(state):
  66. print(f"Printing from the local main process {state.local_process_index}")
  67. def print_last(state):
  68. print(f"Printing from the last process {state.process_index}")
  69. def print_on(state, process_idx):
  70. print(f"Printing from process {process_idx}: {state.process_index}")
  71. def process_execution_check():
  72. accelerator = Accelerator()
  73. num_processes = accelerator.num_processes
  74. # Test main_process_first context manager
  75. path = Path("check_main_process_first.txt")
  76. with accelerator.main_process_first():
  77. if accelerator.is_main_process:
  78. time.sleep(0.1) # ensure main process takes longest
  79. with open(path, "a+") as f:
  80. f.write("Currently in the main process\n")
  81. else:
  82. with open(path, "a+") as f:
  83. f.write("Now on another process\n")
  84. accelerator.wait_for_everyone()
  85. if accelerator.is_main_process:
  86. with open(path) as f:
  87. text = "".join(f.readlines())
  88. try:
  89. assert text.startswith("Currently in the main process\n"), "Main process was not first"
  90. if num_processes > 1:
  91. assert text.endswith("Now on another process\n"), "Main process was not first"
  92. assert text.count("Now on another process\n") == accelerator.num_processes - 1, (
  93. f"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}"
  94. )
  95. except AssertionError:
  96. path.unlink()
  97. raise
  98. if accelerator.is_main_process and path.exists():
  99. path.unlink()
  100. accelerator.wait_for_everyone()
  101. # Test the decorators
  102. f = io.StringIO()
  103. with contextlib.redirect_stdout(f):
  104. accelerator.on_main_process(print_main)(accelerator.state)
  105. result = f.getvalue().rstrip()
  106. if accelerator.is_main_process:
  107. assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
  108. else:
  109. assert f.getvalue().rstrip() == "", f'{result} != ""'
  110. f.truncate(0)
  111. f.seek(0)
  112. with contextlib.redirect_stdout(f):
  113. accelerator.on_local_main_process(print_local_main)(accelerator.state)
  114. if accelerator.is_local_main_process:
  115. assert f.getvalue().rstrip() == "Printing from the local main process 0"
  116. else:
  117. assert f.getvalue().rstrip() == ""
  118. f.truncate(0)
  119. f.seek(0)
  120. with contextlib.redirect_stdout(f):
  121. accelerator.on_last_process(print_last)(accelerator.state)
  122. if accelerator.is_last_process:
  123. assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}"
  124. else:
  125. assert f.getvalue().rstrip() == ""
  126. f.truncate(0)
  127. f.seek(0)
  128. for process_idx in range(num_processes):
  129. with contextlib.redirect_stdout(f):
  130. accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx)
  131. if accelerator.process_index == process_idx:
  132. assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}"
  133. else:
  134. assert f.getvalue().rstrip() == ""
  135. f.truncate(0)
  136. f.seek(0)
  137. def init_state_check():
  138. # Test we can instantiate this twice in a row.
  139. state = AcceleratorState()
  140. if state.local_process_index == 0:
  141. print("Testing, testing. 1, 2, 3.")
  142. print(state)
  143. def rng_sync_check():
  144. state = AcceleratorState()
  145. synchronize_rng_states(["torch"])
  146. assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."
  147. if state.distributed_type == DistributedType.MULTI_GPU:
  148. synchronize_rng_states(["cuda"])
  149. assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."
  150. elif state.distributed_type == DistributedType.MULTI_XPU:
  151. synchronize_rng_states(["xpu"])
  152. assert are_the_same_tensors(torch.xpu.get_rng_state()), "RNG states improperly synchronized on XPU."
  153. generator = torch.Generator()
  154. synchronize_rng_states(["generator"], generator=generator)
  155. assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."
  156. if state.local_process_index == 0:
  157. print("All rng are properly synched.")
  158. def dl_preparation_check():
  159. state = AcceleratorState()
  160. length = 32 * state.num_processes
  161. dl = DataLoader(range(length), batch_size=8)
  162. dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
  163. result = []
  164. for batch in dl:
  165. result.append(gather(batch))
  166. result = torch.cat(result)
  167. assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."
  168. dl = DataLoader(range(length), batch_size=8)
  169. dl = prepare_data_loader(
  170. dl,
  171. state.device,
  172. state.num_processes,
  173. state.process_index,
  174. put_on_device=True,
  175. split_batches=True,
  176. )
  177. result = []
  178. for batch in dl:
  179. result.append(gather(batch))
  180. result = torch.cat(result)
  181. assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."
  182. if state.process_index == 0:
  183. print("Non-shuffled dataloader passing.")
  184. dl = DataLoader(range(length), batch_size=8, shuffle=True)
  185. dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
  186. result = []
  187. for batch in dl:
  188. result.append(gather(batch))
  189. result = torch.cat(result).tolist()
  190. result.sort()
  191. assert result == list(range(length)), "Wrong shuffled dataloader result."
  192. dl = DataLoader(range(length), batch_size=8, shuffle=True)
  193. dl = prepare_data_loader(
  194. dl,
  195. state.device,
  196. state.num_processes,
  197. state.process_index,
  198. put_on_device=True,
  199. split_batches=True,
  200. )
  201. result = []
  202. for batch in dl:
  203. result.append(gather(batch))
  204. result = torch.cat(result).tolist()
  205. result.sort()
  206. assert result == list(range(length)), "Wrong shuffled dataloader result."
  207. if state.local_process_index == 0:
  208. print("Shuffled dataloader passing.")
  209. def central_dl_preparation_check():
  210. state = AcceleratorState()
  211. length = 32 * state.num_processes
  212. dl = DataLoader(range(length), batch_size=8)
  213. dl = prepare_data_loader(
  214. dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
  215. )
  216. result = []
  217. for batch in dl:
  218. result.append(gather(batch))
  219. result = torch.cat(result)
  220. assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."
  221. dl = DataLoader(range(length), batch_size=8)
  222. dl = prepare_data_loader(
  223. dl,
  224. state.device,
  225. state.num_processes,
  226. state.process_index,
  227. put_on_device=True,
  228. split_batches=True,
  229. dispatch_batches=True,
  230. )
  231. result = []
  232. for batch in dl:
  233. result.append(gather(batch))
  234. result = torch.cat(result)
  235. assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."
  236. if state.process_index == 0:
  237. print("Non-shuffled central dataloader passing.")
  238. dl = DataLoader(range(length), batch_size=8, shuffle=True)
  239. dl = prepare_data_loader(
  240. dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
  241. )
  242. result = []
  243. for batch in dl:
  244. result.append(gather(batch))
  245. result = torch.cat(result).tolist()
  246. result.sort()
  247. assert result == list(range(length)), "Wrong shuffled dataloader result."
  248. dl = DataLoader(range(length), batch_size=8, shuffle=True)
  249. dl = prepare_data_loader(
  250. dl,
  251. state.device,
  252. state.num_processes,
  253. state.process_index,
  254. put_on_device=True,
  255. split_batches=True,
  256. dispatch_batches=True,
  257. )
  258. result = []
  259. for batch in dl:
  260. result.append(gather(batch))
  261. result = torch.cat(result).tolist()
  262. result.sort()
  263. assert result == list(range(length)), "Wrong shuffled dataloader result."
  264. if state.local_process_index == 0:
  265. print("Shuffled central dataloader passing.")
  266. def custom_sampler_check():
  267. state = AcceleratorState()
  268. class CustomDataset(Dataset):
  269. def __init__(self, data):
  270. self.data = data
  271. def __len__(self):
  272. return len(self.data)
  273. def __getitem__(self, index):
  274. return self.data[index]
  275. class CustomBatchSampler:
  276. def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
  277. self.batch_size = batch_size
  278. self.data_index = np.arange(dataset_length)
  279. self.shuffle = shuffle
  280. def __iter__(self):
  281. num_batches = len(self)
  282. if self.shuffle:
  283. index = np.random.permutation(self.data_index)
  284. else:
  285. index = self.data_index
  286. output = np.array_split(index, num_batches)
  287. yield from output
  288. def __len__(self):
  289. return math.ceil(len(self.data_index) / self.batch_size)
  290. dataset = CustomDataset(range(32 * state.num_processes))
  291. sampler = CustomBatchSampler(len(dataset), batch_size=8)
  292. dl = DataLoader(dataset, batch_sampler=sampler)
  293. dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
  294. # We need just ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler` is indeed the old batch sampler
  295. if hasattr(dl.batch_sampler, "batch_sampler"):
  296. assert isinstance(dl.batch_sampler.batch_sampler, CustomBatchSampler), (
  297. "Custom sampler was changed after calling `prepare_data_loader`"
  298. )
  299. else:
  300. assert isinstance(dl.batch_sampler, CustomBatchSampler), (
  301. "Custom sampler was changed after calling `prepare_data_loader`"
  302. )
  303. def check_seedable_sampler():
  304. # Set seed
  305. set_seed(42)
  306. train_set = RegressionDataset(length=10, seed=42)
  307. train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
  308. config = DataLoaderConfiguration(use_seedable_sampler=True)
  309. accelerator = Accelerator(dataloader_config=config)
  310. train_dl = accelerator.prepare(train_dl)
  311. original_items = []
  312. for _ in range(3):
  313. for batch in train_dl:
  314. original_items.append(batch["x"])
  315. original_items = torch.cat(original_items)
  316. # Set seed again and the epoch
  317. set_seed(42)
  318. train_dl.set_epoch(0)
  319. new_items = []
  320. for _ in range(3):
  321. for batch in train_dl:
  322. new_items.append(batch["x"])
  323. new_items = torch.cat(new_items)
  324. assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
  325. def check_seedable_sampler_in_batch_sampler_shard():
  326. set_seed(42)
  327. config = DataLoaderConfiguration(use_seedable_sampler=True)
  328. accelerator = Accelerator(dataloader_config=config)
  329. assert accelerator.num_processes > 1, "This test requires more than one process."
  330. dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True)
  331. prepared_data_loader = prepare_data_loader(
  332. dataloader=dataloader,
  333. use_seedable_sampler=True,
  334. )
  335. target_sampler = prepared_data_loader.batch_sampler.batch_sampler.sampler
  336. assert isinstance(target_sampler, SeedableRandomSampler), (
  337. "Sampler in BatchSamplerShard is not SeedableRandomSampler."
  338. )
  339. def check_seedable_sampler_with_data_seed():
  340. # Set seed
  341. set_seed(42)
  342. data_seed = 42
  343. train_set = RegressionDataset(length=10, seed=42)
  344. train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
  345. config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)
  346. accelerator = Accelerator(dataloader_config=config)
  347. prepared_dl = accelerator.prepare(train_dl)
  348. original_items = []
  349. for _ in range(3):
  350. for batch in prepared_dl:
  351. original_items.append(batch["x"])
  352. original_items = torch.cat(original_items)
  353. # Set new data seed
  354. config.data_seed = 43
  355. accelerator = Accelerator(dataloader_config=config)
  356. prepared_dl = accelerator.prepare(train_dl)
  357. new_items = []
  358. for _ in range(3):
  359. for batch in prepared_dl:
  360. new_items.append(batch["x"])
  361. new_items = torch.cat(new_items)
  362. assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
  363. def mock_training(length, batch_size, generator, use_seedable_sampler=False):
  364. set_seed(42)
  365. generator.manual_seed(42)
  366. train_set = RegressionDataset(length=length, seed=42)
  367. train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
  368. model = RegressionModel()
  369. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  370. for epoch in range(3):
  371. for batch in train_dl:
  372. model.zero_grad()
  373. output = model(batch["x"])
  374. loss = torch.nn.functional.mse_loss(output, batch["y"])
  375. loss.backward()
  376. optimizer.step()
  377. return train_set, model
  378. def training_check(use_seedable_sampler=False):
  379. state = AcceleratorState()
  380. generator = torch.Generator()
  381. batch_size = 8
  382. length = batch_size * 4 * state.num_processes
  383. train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
  384. assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
  385. assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
  386. accelerator = Accelerator()
  387. train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
  388. model = RegressionModel()
  389. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  390. train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
  391. set_seed(42)
  392. generator.manual_seed(42)
  393. for _ in range(3):
  394. for batch in train_dl:
  395. model.zero_grad()
  396. output = model(batch["x"])
  397. loss = torch.nn.functional.mse_loss(output, batch["y"])
  398. accelerator.backward(loss)
  399. optimizer.step()
  400. model = accelerator.unwrap_model(model).cpu()
  401. torch.testing.assert_close(
  402. old_model.a,
  403. model.a,
  404. atol=ATOL,
  405. rtol=RTOL,
  406. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  407. )
  408. torch.testing.assert_close(
  409. old_model.b,
  410. model.b,
  411. atol=ATOL,
  412. rtol=RTOL,
  413. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  414. )
  415. accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")
  416. dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
  417. accelerator = Accelerator(dataloader_config=dataloader_config)
  418. train_dl = generate_baseline_dataloader(
  419. train_set, generator, batch_size * state.num_processes, use_seedable_sampler
  420. )
  421. model = RegressionModel()
  422. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  423. train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
  424. set_seed(42)
  425. generator.manual_seed(42)
  426. for _ in range(3):
  427. for batch in train_dl:
  428. model.zero_grad()
  429. output = model(batch["x"])
  430. loss = torch.nn.functional.mse_loss(output, batch["y"])
  431. accelerator.backward(loss)
  432. optimizer.step()
  433. model = accelerator.unwrap_model(model).cpu()
  434. torch.testing.assert_close(
  435. old_model.a,
  436. model.a,
  437. atol=ATOL,
  438. rtol=RTOL,
  439. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  440. )
  441. torch.testing.assert_close(
  442. old_model.b,
  443. model.b,
  444. atol=ATOL,
  445. rtol=RTOL,
  446. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  447. )
  448. accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")
  449. # FP32 wrapper check
  450. if is_cuda_available() or is_mps_available():
  451. # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
  452. print("Keep fp32 wrapper check.")
  453. AcceleratorState._reset_state()
  454. accelerator = Accelerator(mixed_precision="fp16")
  455. model = torch.nn.Linear(2, 4)
  456. model = accelerator.prepare(model)
  457. model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
  458. # Run forward with fp16 as input.
  459. # When the model is with mixed precision wrapper, no error will be raised.
  460. input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device)
  461. output = model_with_fp32_wrapper(input_tensor)
  462. # BF16 support
  463. if is_bf16_available():
  464. # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
  465. print("BF16 training check.")
  466. AcceleratorState._reset_state()
  467. dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
  468. accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config)
  469. train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
  470. model = RegressionModel()
  471. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  472. train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
  473. set_seed(42)
  474. generator.manual_seed(42)
  475. for _ in range(3):
  476. for batch in train_dl:
  477. model.zero_grad()
  478. output = model(batch["x"])
  479. loss = torch.nn.functional.mse_loss(output, batch["y"])
  480. accelerator.backward(loss)
  481. optimizer.step()
  482. model = accelerator.unwrap_model(model).cpu()
  483. torch.testing.assert_close(
  484. old_model.a,
  485. model.a,
  486. atol=ATOL,
  487. rtol=RTOL,
  488. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  489. )
  490. torch.testing.assert_close(
  491. old_model.b,
  492. model.b,
  493. atol=ATOL,
  494. rtol=RTOL,
  495. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  496. )
  497. # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of numerical error)
  498. if is_fp16_available() and not is_hpu_available():
  499. # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
  500. print("FP16 training check.")
  501. AcceleratorState._reset_state()
  502. dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
  503. accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config)
  504. train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
  505. model = RegressionModel()
  506. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  507. train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
  508. set_seed(42)
  509. generator.manual_seed(42)
  510. for _ in range(3):
  511. for batch in train_dl:
  512. model.zero_grad()
  513. output = model(batch["x"])
  514. loss = torch.nn.functional.mse_loss(output, batch["y"])
  515. accelerator.backward(loss)
  516. optimizer.step()
  517. model = accelerator.unwrap_model(model).cpu()
  518. torch.testing.assert_close(
  519. old_model.a,
  520. model.a,
  521. atol=ATOL,
  522. rtol=RTOL,
  523. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  524. )
  525. torch.testing.assert_close(
  526. old_model.b,
  527. model.b,
  528. atol=ATOL,
  529. rtol=RTOL,
  530. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  531. )
  532. # IPEX CPU tests
  533. if is_ipex_available():
  534. print("ipex BF16 training check.")
  535. AcceleratorState._reset_state()
  536. dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
  537. accelerator = Accelerator(mixed_precision="bf16", cpu=True, dataloader_config=dataloader_config)
  538. train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
  539. model = RegressionModel()
  540. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  541. train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
  542. set_seed(42)
  543. generator.manual_seed(42)
  544. for _ in range(3):
  545. for batch in train_dl:
  546. model.zero_grad()
  547. output = model(batch["x"])
  548. loss = torch.nn.functional.mse_loss(output, batch["y"])
  549. accelerator.backward(loss)
  550. optimizer.step()
  551. model = accelerator.unwrap_model(model).cpu()
  552. torch.testing.assert_close(
  553. old_model.a,
  554. model.a,
  555. atol=ATOL,
  556. rtol=RTOL,
  557. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  558. )
  559. torch.testing.assert_close(
  560. old_model.b,
  561. model.b,
  562. atol=ATOL,
  563. rtol=RTOL,
  564. msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
  565. )
  566. def test_split_between_processes_dataset(datasets_Dataset):
  567. state = AcceleratorState()
  568. data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)])
  569. with state.split_between_processes(data, apply_padding=False) as results:
  570. assert len(results) == 2, (
  571. f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
  572. )
  573. data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)])
  574. with state.split_between_processes(data, apply_padding=False) as results:
  575. if state.is_last_process:
  576. assert len(results) == 1, (
  577. f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}"
  578. )
  579. else:
  580. assert len(results) == 2, (
  581. f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}"
  582. )
  583. state.wait_for_everyone()
  584. odd_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)])
  585. even_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)])
  586. for data in [odd_data, even_data]:
  587. expected_output = data["k"]
  588. with state.split_between_processes(data, apply_padding=True) as results:
  589. if state.num_processes == 1:
  590. assert len(results) == len(data), (
  591. f"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}"
  592. )
  593. else:
  594. assert len(results) == 2, (
  595. f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
  596. )
  597. results_per_process = []
  598. for result in results:
  599. results_per_process.append(result)
  600. state.wait_for_everyone()
  601. gathered_results = gather_object(results_per_process)
  602. output = [r["k"] for r in gathered_results[: len(data)]]
  603. assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
  604. def test_split_between_processes_list():
  605. state = AcceleratorState()
  606. data = list(range(0, 2 * state.num_processes))
  607. with state.split_between_processes(data) as results:
  608. assert len(results) == 2, (
  609. f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
  610. )
  611. state.wait_for_everyone()
  612. even_data = list(range(0, (2 * state.num_processes)))
  613. odd_data = list(range(0, (2 * state.num_processes) - 1))
  614. for data in [odd_data, even_data]:
  615. expected_output = data
  616. with state.split_between_processes(data, apply_padding=True) as results:
  617. num_samples_per_device = math.ceil(len(data) / state.num_processes)
  618. # Test all processes gets the correct number of item(s)
  619. assert len(results) == num_samples_per_device, (
  620. f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
  621. )
  622. results_per_process = []
  623. for result in results:
  624. results_per_process.append(result)
  625. state.wait_for_everyone()
  626. gathered_results = gather_object(results_per_process)
  627. output = gathered_results[: len(data)]
  628. assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
  629. def test_split_between_processes_nested_dict():
  630. state = AcceleratorState()
  631. a = [1, 2, 3, 4, 5, 6, 7, 8]
  632. b = ["a", "b", "c", "d", "e", "f", "g", "h"]
  633. c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
  634. if state.num_processes in (1, 2, 4):
  635. data = {"a": a, "b": b, "c": c}
  636. data_copy = deepcopy(data)
  637. with state.split_between_processes(data) as results:
  638. if state.process_index == 0:
  639. assert results["a"] == data_copy["a"][: 8 // state.num_processes]
  640. elif state.num_processes == 2:
  641. assert results["a"] == data_copy["a"][4:]
  642. elif state.process_index == 3:
  643. # We return a list each time
  644. assert results["a"] == data_copy["a"][-2:], f"Expected: {data_copy['a'][-2]}, Actual: {results['a']}"
  645. if state.process_index == 0:
  646. assert results["b"] == data_copy["b"][: 8 // state.num_processes]
  647. elif state.num_processes == 2:
  648. assert results["b"] == data_copy["b"][4:]
  649. elif state.process_index == 3:
  650. assert results["b"] == data_copy["b"][-2:]
  651. if state.process_index == 0:
  652. assert torch.allclose(results["c"], data_copy["c"][: 8 // state.num_processes]), (
  653. f"Did not obtain expected values on process 0, expected `{data['c'][: 8 // state.num_processes]}`, received: {results['c']}"
  654. )
  655. elif state.num_processes == 2:
  656. assert torch.allclose(results["c"], data_copy["c"][4:]), (
  657. f"Did not obtain expected values on process 2, expected `{data['c'][4:]}`, received: {results['c']}"
  658. )
  659. elif state.process_index == 3:
  660. assert torch.allclose(results["c"], data_copy["c"][-2:]), (
  661. f"Did not obtain expected values on process 4, expected `{data['c'][-2:]}`, received: {results['c']}"
  662. )
  663. state.wait_for_everyone()
  664. def test_split_between_processes_tensor():
  665. state = AcceleratorState()
  666. if state.num_processes > 1:
  667. data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device)
  668. with state.split_between_processes(data) as results:
  669. if state.process_index == 0:
  670. expected = torch.tensor([[0, 1, 2, 3]]).to(state.device)
  671. else:
  672. expected = torch.tensor([[4, 5, 6, 7]]).to(state.device)
  673. torch.testing.assert_close(results, expected)
  674. state.wait_for_everyone()
  675. even_data = torch.tensor([[i] for i in range(2 * state.num_processes)]).to(state.device)
  676. odd_data = torch.tensor([[i] for i in range(2 * state.num_processes - 1)]).to(state.device)
  677. for data in [even_data, odd_data]:
  678. expected_output = [torch.tensor(i) for i in data.tolist()]
  679. with state.split_between_processes(data, apply_padding=True) as results:
  680. num_samples_per_device = math.ceil(len(data) / state.num_processes)
  681. assert len(results) == num_samples_per_device, (
  682. f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
  683. )
  684. results_per_process = []
  685. for result in results:
  686. results_per_process.append(result.to("cpu"))
  687. state.wait_for_everyone()
  688. gathered_results = gather_object(results_per_process)
  689. output = gathered_results[: len(data)]
  690. assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
  691. def test_split_between_processes_evenly():
  692. state = AcceleratorState()
  693. if state.num_processes in (1, 2, 4, 8):
  694. data = list(range(17))
  695. num_samples_per_process = len(data) // state.num_processes
  696. num_extras = len(data) % state.num_processes
  697. with state.split_between_processes(data) as results:
  698. if state.process_index < num_extras:
  699. assert len(results) == num_samples_per_process + 1, (
  700. f"Each Process should have even elements. Expected: {num_samples_per_process + 1}, Actual: {len(results)}"
  701. )
  702. else:
  703. assert len(results) == num_samples_per_process, (
  704. f"Each Process should have even elements. Expected: {num_samples_per_process}, Actual: {len(results)}"
  705. )
  706. state.wait_for_everyone()
def test_trigger():
    """Check that a trigger set only on the main process is observed everywhere.

    Every process must run this function. Per the comment below, ``check_trigger``
    calls ``all_reduce``, so the three ``check_trigger`` calls presumably must
    happen in the same order on all processes.

    NOTE: each ``check_trigger`` call sits inside an ``assert``, so running with
    ``python -O`` would skip those calls entirely.
    """
    accelerator = Accelerator()
    # should start with being false
    assert accelerator.check_trigger() is False
    # set the trigger (a "breakpoint") on the main process only
    if accelerator.is_main_process:
        accelerator.set_trigger()
    # check it's been activated across all processes
    # calls `all_reduce` and triggers a sync
    assert accelerator.check_trigger() is True
    # check it's been reset after the sync
    assert accelerator.check_trigger() is False
  719. def test_reinstantiated_state():
  720. import pytest
  721. AcceleratorState._reset_state()
  722. simple_model = torch.nn.Linear(1, 1)
  723. # First define an accelerator
  724. accelerator = Accelerator()
  725. # Then call `reset_state`, breaking the state existing in the accelerator
  726. AcceleratorState._reset_state()
  727. # Now try and prepare a simple model, should raise the custom error early
  728. with pytest.raises(AttributeError) as cm:
  729. accelerator.prepare(simple_model)
  730. assert "`AcceleratorState` object has no attribute" in str(cm.value.args[0])
  731. assert "This happens if `AcceleratorState._reset_state()`" in str(cm.value.args[0])
  732. def main():
  733. accelerator = Accelerator()
  734. state = accelerator.state
  735. if state.local_process_index == 0:
  736. print("**Initialization**")
  737. init_state_check()
  738. state.wait_for_everyone()
  739. if state.distributed_type == DistributedType.MULTI_GPU:
  740. num_processes_per_node = torch.cuda.device_count()
  741. else:
  742. num_processes_per_node = state.num_processes
  743. # We only run this test on non-multinode
  744. if num_processes_per_node == state.num_processes:
  745. if state.process_index == 0:
  746. print("\n**Test process execution**")
  747. process_execution_check()
  748. if state.process_index == 0:
  749. print("\n**Test split between processes as a list**")
  750. test_split_between_processes_list()
  751. if state.process_index == 0:
  752. print("\n**Test split between processes as a dict**")
  753. test_split_between_processes_nested_dict()
  754. if state.process_index == 0:
  755. print("\n**Test split between processes as a tensor**")
  756. test_split_between_processes_tensor()
  757. if state.process_index == 0:
  758. print("\n**Test split between processes evenly**")
  759. test_split_between_processes_evenly()
  760. if state.process_index == 0:
  761. print("\n**Test split between processes as a datasets.Dataset**")
  762. if is_datasets_available():
  763. from datasets import Dataset as datasets_Dataset
  764. test_split_between_processes_dataset(datasets_Dataset)
  765. else:
  766. print("Skipped because Hugging Face datasets is not available")
  767. if state.local_process_index == 0:
  768. print("\n**Test random number generator synchronization**")
  769. rng_sync_check()
  770. if state.local_process_index == 0:
  771. print("\n**DataLoader integration test**")
  772. dl_preparation_check()
  773. if state.distributed_type != DistributedType.XLA:
  774. central_dl_preparation_check()
  775. custom_sampler_check()
  776. check_seedable_sampler()
  777. check_seedable_sampler_with_data_seed()
  778. if state.num_processes > 1:
  779. check_seedable_sampler_in_batch_sampler_shard()
  780. # Trainings are not exactly the same in DeepSpeed and CPU mode
  781. if state.distributed_type == DistributedType.DEEPSPEED:
  782. return
  783. if state.local_process_index == 0:
  784. print("\n**Training integration test**")
  785. training_check(use_seedable_sampler=False)
  786. training_check(use_seedable_sampler=True)
  787. if state.local_process_index == 0:
  788. print("\n**Breakpoint trigger test**")
  789. test_trigger()
  790. if is_pytest_available():
  791. if state.local_process_index == 0:
  792. print("\n**Test reinstantiated state**")
  793. test_reinstantiated_state()
  794. state.destroy_process_group()
# Run the check suite only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()