| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947 |
- #!/usr/bin/env python
- # Copyright 2021 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import contextlib
- import io
- import math
- import time
- from copy import deepcopy
- from pathlib import Path
- import numpy as np
- import torch
- from torch.utils.data import DataLoader, Dataset
- from accelerate import Accelerator
- from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
- from accelerate.state import AcceleratorState
- from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
- from accelerate.utils import (
- DataLoaderConfiguration,
- DistributedType,
- gather,
- gather_object,
- is_bf16_available,
- is_cuda_available,
- is_datasets_available,
- is_fp16_available,
- is_hpu_available,
- is_ipex_available,
- is_mps_available,
- is_pytest_available,
- set_seed,
- synchronize_rng_states,
- )
# Absolute/relative tolerances used when comparing distributed results against the
# single-process CPU reference. HPU runs get looser tolerances (1e-3 instead of 1e-6).
ATOL = RTOL = 1e-3 if is_hpu_available() else 1e-6
def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):
    "Creates a dataloader that can also use the `SeedableRandomSampler`"
    if not use_seedable_sampler:
        # Plain shuffled loader whose order is driven by `generator`.
        return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)

    # The SeedableRandomSampler is needed during distributed setups
    # for full reproducibility across processes with the `DataLoader`
    seedable_sampler = SeedableRandomSampler(
        generator=generator,
        data_source=train_set,
        num_samples=len(train_set),
    )
    return DataLoader(train_set, batch_size=batch_size, sampler=seedable_sampler)
def print_main(state):
    """Print a marker line tagged with this process's global index (used by the main-process decorator check)."""
    message = f"Printing from the main process {state.process_index}"
    print(message)
def print_local_main(state):
    """Print a marker line tagged with this process's node-local index (used by the local-main decorator check)."""
    message = f"Printing from the local main process {state.local_process_index}"
    print(message)
def print_last(state):
    """Print a marker line tagged with this process's global index (used by the last-process decorator check)."""
    message = f"Printing from the last process {state.process_index}"
    print(message)
def print_on(state, process_idx):
    """Print the targeted process index next to the index of the process actually printing."""
    print("Printing from process {}: {}".format(process_idx, state.process_index))
def _capture_stdout(fn, *args):
    """Call ``fn(*args)`` and return everything it printed to stdout, trailing whitespace stripped."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        fn(*args)
    return buffer.getvalue().rstrip()


def _check_main_process_first(accelerator):
    """Check that `main_process_first` lets the main process write a shared file before all others."""
    num_processes = accelerator.num_processes
    path = Path("check_main_process_first.txt")
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            time.sleep(0.1)  # ensure main process takes longest
            # Truncate ("w") rather than append: the main process is expected to write first,
            # so this only discards stale content left behind by an aborted earlier run.
            with open(path, "w", encoding="utf-8") as f:
                f.write("Currently in the main process\n")
        else:
            with open(path, "a+", encoding="utf-8") as f:
                f.write("Now on another process\n")
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        try:
            text = path.read_text(encoding="utf-8")
            assert text.startswith("Currently in the main process\n"), "Main process was not first"
            if num_processes > 1:
                assert text.endswith("Now on another process\n"), "Main process was not first"
            assert text.count("Now on another process\n") == num_processes - 1, (
                f"Only wrote to file {text.count('Now on another process') + 1} times, not {num_processes}"
            )
        finally:
            # Always remove the file, whether the checks passed, failed or reading it raised.
            path.unlink(missing_ok=True)
    accelerator.wait_for_everyone()


def _check_process_decorators(accelerator):
    """Check that each process-selection decorator runs the wrapped function only on the intended process."""
    state = accelerator.state

    output = _capture_stdout(accelerator.on_main_process(print_main), state)
    expected = "Printing from the main process 0" if accelerator.is_main_process else ""
    assert output == expected, f"{output!r} != {expected!r}"

    output = _capture_stdout(accelerator.on_local_main_process(print_local_main), state)
    expected = "Printing from the local main process 0" if accelerator.is_local_main_process else ""
    assert output == expected, f"{output!r} != {expected!r}"

    output = _capture_stdout(accelerator.on_last_process(print_last), state)
    expected = f"Printing from the last process {state.num_processes - 1}" if accelerator.is_last_process else ""
    assert output == expected, f"{output!r} != {expected!r}"

    for process_idx in range(accelerator.num_processes):
        output = _capture_stdout(accelerator.on_process(print_on, process_index=process_idx), state, process_idx)
        if accelerator.process_index == process_idx:
            expected = f"Printing from process {process_idx}: {accelerator.process_index}"
        else:
            expected = ""
        assert output == expected, f"{output!r} != {expected!r}"


def process_execution_check():
    """Exercise `main_process_first` and the `on_*process` decorators of `Accelerator`."""
    accelerator = Accelerator()
    _check_main_process_first(accelerator)
    _check_process_decorators(accelerator)
def init_state_check():
    """Build an `AcceleratorState` again (main() already created an Accelerator) and print it."""
    # Instantiating the shared state twice in a row must work.
    state = AcceleratorState()
    is_local_main = state.local_process_index == 0
    if is_local_main:
        print("Testing, testing. 1, 2, 3.")
    print(state)
def rng_sync_check():
    """Synchronize the CPU, device and generator RNGs and assert every process holds the same state."""
    state = AcceleratorState()

    synchronize_rng_states(["torch"])
    assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."

    # Only the device RNG matching the current backend is synchronized (first match wins).
    # The getters are lambdas so `torch.cuda` / `torch.xpu` are only touched when selected.
    device_checks = (
        (
            DistributedType.MULTI_GPU,
            "cuda",
            lambda: torch.cuda.get_rng_state(),
            "RNG states improperly synchronized on GPU.",
        ),
        (
            DistributedType.MULTI_XPU,
            "xpu",
            lambda: torch.xpu.get_rng_state(),
            "RNG states improperly synchronized on XPU.",
        ),
    )
    for distributed_type, rng_type, read_rng_state, error_message in device_checks:
        if state.distributed_type == distributed_type:
            synchronize_rng_states([rng_type])
            assert are_the_same_tensors(read_rng_state()), error_message
            break

    generator = torch.Generator()
    synchronize_rng_states(["generator"], generator=generator)
    assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

    if state.local_process_index == 0:
        print("All rng are properly synched.")
def _gather_prepared_range(state, length, shuffle=False, **prepare_kwargs):
    """Prepare a batch-size-8 `DataLoader` over `range(length)` for this process and gather every batch.

    Extra keyword arguments (e.g. `split_batches`, `dispatch_batches`) are forwarded to
    `prepare_data_loader`. Returns the concatenation of all gathered batches.
    """
    dl = DataLoader(range(length), batch_size=8, shuffle=shuffle)
    dl = prepare_data_loader(
        dl, state.device, state.num_processes, state.process_index, put_on_device=True, **prepare_kwargs
    )
    return torch.cat([gather(batch) for batch in dl])


def _assert_in_order(state, length, **prepare_kwargs):
    """Without shuffling, the gathered samples must be exactly 0..length-1, in order."""
    result = _gather_prepared_range(state, length, **prepare_kwargs)
    assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."


def _assert_permutation(state, length, **prepare_kwargs):
    """With shuffling, the gathered samples must be a permutation of 0..length-1."""
    result = sorted(_gather_prepared_range(state, length, shuffle=True, **prepare_kwargs).tolist())
    assert result == list(range(length)), "Wrong shuffled dataloader result."


def dl_preparation_check():
    """Check `prepare_data_loader` (per-process loading) with and without `split_batches`, shuffled or not."""
    state = AcceleratorState()
    length = 32 * state.num_processes

    for split_batches in (False, True):
        _assert_in_order(state, length, split_batches=split_batches)
    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    for split_batches in (False, True):
        _assert_permutation(state, length, split_batches=split_batches)
    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")


def central_dl_preparation_check():
    """Same checks as `dl_preparation_check`, but with `dispatch_batches=True`."""
    state = AcceleratorState()
    length = 32 * state.num_processes

    for split_batches in (False, True):
        _assert_in_order(state, length, split_batches=split_batches, dispatch_batches=True)
    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    for split_batches in (False, True):
        _assert_permutation(state, length, split_batches=split_batches, dispatch_batches=True)
    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")
def custom_sampler_check():
    """Check that `prepare_data_loader` keeps a user-defined batch sampler instead of replacing it."""
    state = AcceleratorState()

    class CustomDataset(Dataset):
        # Minimal map-style dataset over any indexable `data`.
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomBatchSampler:
        # Yields `len(self)` index arrays of near-equal size, optionally shuffled with numpy's global RNG.
        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
            self.batch_size = batch_size
            self.data_index = np.arange(dataset_length)
            self.shuffle = shuffle

        def __len__(self):
            return math.ceil(len(self.data_index) / self.batch_size)

        def __iter__(self):
            batch_count = len(self)
            order = np.random.permutation(self.data_index) if self.shuffle else self.data_index
            yield from np.array_split(order, batch_count)

    dataset = CustomDataset(range(32 * state.num_processes))
    dl = DataLoader(dataset, batch_sampler=CustomBatchSampler(len(dataset), batch_size=8))
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
    # The custom sampler must survive either as `dl.batch_sampler` itself or one level
    # deeper as `dl.batch_sampler.batch_sampler` when it has been wrapped.
    kept_sampler = getattr(dl.batch_sampler, "batch_sampler", dl.batch_sampler)
    assert isinstance(kept_sampler, CustomBatchSampler), (
        "Custom sampler was changed after calling `prepare_data_loader`"
    )
def check_seedable_sampler():
    """Re-seeding and rewinding the epoch must replay exactly the same shuffled batches."""
    set_seed(42)
    train_set = RegressionDataset(length=10, seed=42)
    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    train_dl = accelerator.prepare(train_dl)

    # Three full passes over the prepared dataloader.
    original_items = torch.cat([batch["x"] for _ in range(3) for batch in train_dl])

    # Same seed and epoch rewound to 0: the replay must match the first run.
    set_seed(42)
    train_dl.set_epoch(0)
    new_items = torch.cat([batch["x"] for _ in range(3) for batch in train_dl])

    assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
def check_seedable_sampler_in_batch_sampler_shard():
    """In a multi-process run, the sampler wrapped by the prepared batch sampler must be a `SeedableRandomSampler`."""
    set_seed(42)
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    assert accelerator.num_processes > 1, "This test requires more than one process."

    prepared_data_loader = prepare_data_loader(
        dataloader=DataLoader(list(range(10)), batch_size=1, shuffle=True),
        use_seedable_sampler=True,
    )
    inner_sampler = prepared_data_loader.batch_sampler.batch_sampler.sampler
    assert isinstance(inner_sampler, SeedableRandomSampler), (
        "Sampler in BatchSamplerShard is not SeedableRandomSampler."
    )
def check_seedable_sampler_with_data_seed():
    """Changing `data_seed` on the dataloader configuration must change the shuffled order."""
    set_seed(42)
    train_set = RegressionDataset(length=10, seed=42)
    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=42)

    accelerator = Accelerator(dataloader_config=config)
    prepared_dl = accelerator.prepare(train_dl)
    original_items = torch.cat([batch["x"] for _ in range(3) for batch in prepared_dl])

    # Same (mutated) config object, different data seed, freshly prepared dataloader.
    config.data_seed = 43
    accelerator = Accelerator(dataloader_config=config)
    prepared_dl = accelerator.prepare(train_dl)
    new_items = torch.cat([batch["x"] for _ in range(3) for batch in prepared_dl])

    assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
def mock_training(length, batch_size, generator, use_seedable_sampler=False):
    """Single-process reference run: train a `RegressionModel` with plain SGD for 3 epochs.

    Returns the `(dataset, trained_model)` pair so later runs can reuse the same data.
    """
    set_seed(42)
    generator.manual_seed(42)
    dataset = RegressionDataset(length=length, seed=42)
    loader = generate_baseline_dataloader(dataset, generator, batch_size, use_seedable_sampler)
    reference_model = RegressionModel()
    sgd = torch.optim.SGD(reference_model.parameters(), lr=0.1)
    for _epoch in range(3):
        for batch in loader:
            reference_model.zero_grad()
            prediction = reference_model(batch["x"])
            torch.nn.functional.mse_loss(prediction, batch["y"]).backward()
            sgd.step()
    return dataset, reference_model
def _train_regression_model(accelerator, train_dl, generator):
    """Train a fresh `RegressionModel` through `accelerator` for 3 epochs; return it unwrapped on CPU.

    The model and its SGD optimizer are prepared together with `train_dl`. The global seed
    and `generator` are reset to 42 right before training, as in `mock_training`.
    """
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    return accelerator.unwrap_model(model).cpu()


def _assert_same_model(old_model, model):
    """Assert that the `a` and `b` parameters of `model` match the reference within ATOL/RTOL."""
    for param_name in ("a", "b"):
        torch.testing.assert_close(
            getattr(old_model, param_name),
            getattr(model, param_name),
            atol=ATOL,
            rtol=RTOL,
            msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
        )


def training_check(use_seedable_sampler=False):
    """Compare distributed, batch-split and mixed-precision training against a single-process reference.

    Args:
        use_seedable_sampler: build every dataloader on a `SeedableRandomSampler`
            (see `generate_baseline_dataloader`) instead of a plain shuffled `DataLoader`.
    """
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    # Reference run with the global batch size (batch_size * num_processes) in one loop.
    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    # Each process is given a dataloader with the per-process batch size.
    accelerator = Accelerator()
    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
    _assert_same_model(old_model, _train_regression_model(accelerator, train_dl, generator))
    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    # With `split_batches=True` the dataloader is built with the global batch size instead.
    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    train_dl = generate_baseline_dataloader(
        train_set, generator, batch_size * state.num_processes, use_seedable_sampler
    )
    _assert_same_model(old_model, _train_regression_model(accelerator, train_dl, generator))
    accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

    # FP32 wrapper check
    if is_cuda_available() or is_mps_available():
        # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
        print("Keep fp32 wrapper check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")
        model = torch.nn.Linear(2, 4)
        model = accelerator.prepare(model)
        model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
        # Run forward with fp16 as input.
        # When the model is with mixed precision wrapper, no error will be raised.
        input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device)
        model_with_fp32_wrapper(input_tensor)

    def run_mixed_precision_check(mixed_precision, cpu=False):
        # Rebuild the shared state with the requested precision, retrain and compare to the reference.
        AcceleratorState._reset_state()
        mp_dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
        mp_accelerator = Accelerator(mixed_precision=mixed_precision, cpu=cpu, dataloader_config=mp_dataloader_config)
        mp_train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
        _assert_same_model(old_model, _train_regression_model(mp_accelerator, mp_train_dl, generator))

    # BF16 support
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        run_mixed_precision_check("bf16")

    # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of numerical error)
    if is_fp16_available() and not is_hpu_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        run_mixed_precision_check("fp16")

    # IPEX CPU tests
    if is_ipex_available():
        print("ipex BF16 training check.")
        run_mixed_precision_check("bf16", cpu=True)
def test_split_between_processes_dataset(datasets_Dataset):
    """Check `split_between_processes` on a `datasets.Dataset`, with and without padding.

    `datasets_Dataset` is the `datasets.Dataset` class, passed in so this module does not
    import `datasets` unconditionally.
    """
    state = AcceleratorState()

    def make_dataset(size):
        return datasets_Dataset.from_list([{"k": v} for v in range(size)])

    # Even split: every process gets exactly two rows.
    with state.split_between_processes(make_dataset(2 * state.num_processes), apply_padding=False) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )

    # One row short: the last process is the one that comes up short.
    with state.split_between_processes(make_dataset(2 * state.num_processes - 1), apply_padding=False) as results:
        if not state.is_last_process:
            assert len(results) == 2, (
                f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}"
            )
        else:
            assert len(results) == 1, (
                f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}"
            )
    state.wait_for_everyone()

    # With padding, gathering and truncating must reproduce the original rows in order.
    odd_data = make_dataset(2 * state.num_processes - 1)
    even_data = make_dataset(2 * state.num_processes)
    for data in (odd_data, even_data):
        expected_output = data["k"]
        with state.split_between_processes(data, apply_padding=True) as results:
            if state.num_processes != 1:
                assert len(results) == 2, (
                    f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
                )
            else:
                assert len(results) == len(data), (
                    f"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}"
                )
            results_per_process = list(results)
        state.wait_for_everyone()
        output = [row["k"] for row in gather_object(results_per_process)[: len(data)]]
        assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
def test_split_between_processes_list():
    """Check `split_between_processes` on plain lists, with and without padding."""
    state = AcceleratorState()
    world_size = state.num_processes

    with state.split_between_processes(list(range(0, 2 * world_size))) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )
    state.wait_for_everyone()

    # Odd then even lengths: with padding every process gets ceil(len / world_size) items,
    # and the gathered, truncated result must equal the input.
    for data in (list(range(0, 2 * world_size - 1)), list(range(0, 2 * world_size))):
        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / world_size)
            # Test all processes gets the correct number of item(s)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )
            results_per_process = list(results)
        state.wait_for_everyone()
        output = gather_object(results_per_process)[: len(data)]
        assert data == output, f"Gathered results is incorrect. Expected: {data}; Got: {output}"
def test_split_between_processes_nested_dict():
    """Check `split_between_processes` on a dict holding a list, a list of strings and a tensor.

    Runs only for 1, 2 or 4 processes so the 8 items split evenly: every process must
    receive the contiguous slice ``[process_index * chunk, (process_index + 1) * chunk)``
    of every value, where ``chunk = 8 // num_processes``.
    """
    state = AcceleratorState()
    a = [1, 2, 3, 4, 5, 6, 7, 8]
    b = ["a", "b", "c", "d", "e", "f", "g", "h"]
    c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    if state.num_processes in (1, 2, 4):
        data = {"a": a, "b": b, "c": c}
        # Untouched copy: expectations and error messages must not depend on whatever
        # `split_between_processes` does to `data`.
        data_copy = deepcopy(data)
        chunk = len(a) // state.num_processes
        start = state.process_index * chunk
        end = start + chunk
        with state.split_between_processes(data) as results:
            expected_a = data_copy["a"][start:end]
            assert results["a"] == expected_a, (
                f"Did not obtain expected values for `a` on process {state.process_index}, expected `{expected_a}`, received: {results['a']}"
            )
            expected_b = data_copy["b"][start:end]
            assert results["b"] == expected_b, (
                f"Did not obtain expected values for `b` on process {state.process_index}, expected `{expected_b}`, received: {results['b']}"
            )
            expected_c = data_copy["c"][start:end]
            assert torch.allclose(results["c"], expected_c), (
                f"Did not obtain expected values for `c` on process {state.process_index}, expected `{expected_c}`, received: {results['c']}"
            )
    state.wait_for_everyone()
def test_split_between_processes_tensor():
    """Check `split_between_processes` on tensors, unpadded and padded."""
    state = AcceleratorState()
    if state.num_processes > 1:
        data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device)
        # Two rows over N processes: ranks 0 and 1 get one row each, and any further rank gets
        # none (the first `len % N` ranks get the extra item). Slicing gives this rank's share,
        # possibly empty, instead of assuming exactly two processes.
        expected = data[state.process_index : state.process_index + 1]
        with state.split_between_processes(data) as results:
            torch.testing.assert_close(results, expected)
    state.wait_for_everyone()

    even_data = torch.tensor([[i] for i in range(2 * state.num_processes)]).to(state.device)
    odd_data = torch.tensor([[i] for i in range(2 * state.num_processes - 1)]).to(state.device)
    for data in [even_data, odd_data]:
        expected_output = [torch.tensor(i) for i in data.tolist()]
        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / state.num_processes)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )
            # Move rows to CPU so they can be gathered as Python objects and compared with `expected_output`.
            results_per_process = [result.to("cpu") for result in results]
        state.wait_for_everyone()
        output = gather_object(results_per_process)[: len(data)]
        assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
def test_split_between_processes_evenly():
    """17 items: the first `17 % num_processes` ranks get one extra item, the others get the floor share."""
    state = AcceleratorState()
    if state.num_processes in (1, 2, 4, 8):
        data = list(range(17))
        base_share, num_extras = divmod(len(data), state.num_processes)
        with state.split_between_processes(data) as results:
            expected_len = base_share + 1 if state.process_index < num_extras else base_share
            assert len(results) == expected_len, (
                f"Each Process should have even elements. Expected: {expected_len}, Actual: {len(results)}"
            )
    state.wait_for_everyone()
def test_trigger():
    """A trigger set on the main process is seen by every process, then resets after the check."""
    accelerator = Accelerator()

    # Nothing has been triggered yet.
    initially_triggered = accelerator.check_trigger()
    assert initially_triggered is False

    # Only the main process raises the flag.
    if accelerator.is_main_process:
        accelerator.set_trigger()

    # `check_trigger` calls `all_reduce` and triggers a sync, so every process sees it...
    triggered = accelerator.check_trigger()
    assert triggered is True

    # ...and the flag is cleared by that sync.
    triggered_again = accelerator.check_trigger()
    assert triggered_again is False
def test_reinstantiated_state():
    """Preparing with an `Accelerator` whose state was reset must fail early with an explanatory error."""
    import pytest

    AcceleratorState._reset_state()
    simple_model = torch.nn.Linear(1, 1)
    # Build an accelerator, then break its state from under it.
    accelerator = Accelerator()
    AcceleratorState._reset_state()

    with pytest.raises(AttributeError) as exc_info:
        accelerator.prepare(simple_model)
    error_text = str(exc_info.value.args[0])
    for expected_fragment in (
        "`AcceleratorState` object has no attribute",
        "This happens if `AcceleratorState._reset_state()`",
    ):
        assert expected_fragment in error_text
def main():
    """Run every check in this script, in order, on every process.

    Section headers are printed by global process 0 (single-node-only checks) or by local
    process 0 of each node. `destroy_process_group()` is reached on every path, including
    DeepSpeed runs, which skip the training-based checks.
    """
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    init_state_check()
    state.wait_for_everyone()

    if state.distributed_type == DistributedType.MULTI_GPU:
        num_processes_per_node = torch.cuda.device_count()
    else:
        num_processes_per_node = state.num_processes

    # We only run this test on non-multinode
    if num_processes_per_node == state.num_processes:
        if state.process_index == 0:
            print("\n**Test process execution**")
        process_execution_check()

        if state.process_index == 0:
            print("\n**Test split between processes as a list**")
        test_split_between_processes_list()

        if state.process_index == 0:
            print("\n**Test split between processes as a dict**")
        test_split_between_processes_nested_dict()

        if state.process_index == 0:
            print("\n**Test split between processes as a tensor**")
        test_split_between_processes_tensor()

        if state.process_index == 0:
            print("\n**Test split between processes evenly**")
        test_split_between_processes_evenly()

        if state.process_index == 0:
            print("\n**Test split between processes as a datasets.Dataset**")
        if is_datasets_available():
            from datasets import Dataset as datasets_Dataset

            test_split_between_processes_dataset(datasets_Dataset)
        else:
            print("Skipped because Hugging Face datasets is not available")

    if state.local_process_index == 0:
        print("\n**Test random number generator synchronization**")
    rng_sync_check()

    if state.local_process_index == 0:
        print("\n**DataLoader integration test**")
    dl_preparation_check()
    if state.distributed_type != DistributedType.XLA:
        central_dl_preparation_check()
        custom_sampler_check()
        check_seedable_sampler()
        check_seedable_sampler_with_data_seed()

    if state.num_processes > 1:
        check_seedable_sampler_in_batch_sampler_shard()

    # Trainings are not exactly the same in DeepSpeed and CPU mode, so the remaining checks
    # are skipped there. Skipping (rather than returning early) keeps the process-group
    # teardown below reachable.
    if state.distributed_type != DistributedType.DEEPSPEED:
        if state.local_process_index == 0:
            print("\n**Training integration test**")
        training_check(use_seedable_sampler=False)
        training_check(use_seedable_sampler=True)

        if state.local_process_index == 0:
            print("\n**Breakpoint trigger test**")
        test_trigger()

        if is_pytest_available():
            if state.local_process_index == 0:
                print("\n**Test reinstantiated state**")
            test_reinstantiated_state()

    state.destroy_process_group()
# Run all checks when this file is executed as a script (not when it is imported).
if __name__ == "__main__":
    main()
|