data_loader.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import math
  16. from contextlib import suppress
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from packaging import version
  20. from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
  21. from .logging import get_logger
  22. from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
  23. from .utils import (
  24. RNGType,
  25. broadcast,
  26. broadcast_object_list,
  27. compare_versions,
  28. concatenate,
  29. find_batch_size,
  30. get_data_structure,
  31. initialize_tensors,
  32. is_datasets_available,
  33. is_torch_version,
  34. is_torchdata_stateful_dataloader_available,
  35. send_to_device,
  36. slice_tensors,
  37. synchronize_rng_states,
  38. )
  39. logger = get_logger(__name__)
  40. # kwargs of the DataLoader in min version 2.0
  41. _PYTORCH_DATALOADER_KWARGS = {
  42. "batch_size": 1,
  43. "shuffle": False,
  44. "sampler": None,
  45. "batch_sampler": None,
  46. "num_workers": 0,
  47. "collate_fn": None,
  48. "pin_memory": False,
  49. "drop_last": False,
  50. "timeout": 0,
  51. "worker_init_fn": None,
  52. "multiprocessing_context": None,
  53. "generator": None,
  54. "prefetch_factor": 2,
  55. "persistent_workers": False,
  56. "pin_memory_device": "",
  57. }
  58. # kwargs added after by version
  59. _PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}
  60. for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
  61. if is_torch_version(">=", v):
  62. _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
  63. class SeedableRandomSampler(RandomSampler):
  64. """
  65. Same as a random sampler, except that in `__iter__` a seed can be used.
  66. Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
  67. and be fully reproducible on multiple iterations.
  68. If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
  69. (stored in `self.epoch`).
  70. """
  71. def __init__(self, *args, **kwargs):
  72. data_seed = kwargs.pop("data_seed", None)
  73. super().__init__(*args, **kwargs)
  74. self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
  75. self.epoch = 0
  76. def __iter__(self):
  77. if self.generator is None:
  78. self.generator = torch.Generator(
  79. device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
  80. )
  81. self.generator.manual_seed(self.initial_seed)
  82. # Allow `self.epoch` to modify the seed of the generator
  83. seed = self.epoch + self.initial_seed
  84. # print("Setting seed at epoch", self.epoch, seed)
  85. self.generator.manual_seed(seed)
  86. yield from super().__iter__()
  87. self.set_epoch(self.epoch + 1)
  88. def set_epoch(self, epoch: int):
  89. "Sets the current iteration of the sampler."
  90. self.epoch = epoch
  91. class BatchSamplerShard(BatchSampler):
  92. """
  93. Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
  94. always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
  95. Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
  96. at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
  97. Args:
  98. batch_sampler (`torch.utils.data.sampler.BatchSampler`):
  99. The batch sampler to split in several shards.
  100. num_processes (`int`, *optional*, defaults to 1):
  101. The number of processes running concurrently.
  102. process_index (`int`, *optional*, defaults to 0):
  103. The index of the current process.
  104. split_batches (`bool`, *optional*, defaults to `False`):
  105. Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
  106. yielding different full batches on each process.
  107. On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
  108. - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
  109. this argument is set to `False`.
  110. - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
  111. then `[6, 7]` if this argument is set to `True`.
  112. even_batches (`bool`, *optional*, defaults to `True`):
  113. Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
  114. multiple of (original batch size / number of processes).
  115. <Tip warning={true}>
  116. `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
  117. equal to `False`
  118. </Tip>"""
  119. def __init__(
  120. self,
  121. batch_sampler: BatchSampler,
  122. num_processes: int = 1,
  123. process_index: int = 0,
  124. split_batches: bool = False,
  125. even_batches: bool = True,
  126. ):
  127. if split_batches and batch_sampler.batch_size % num_processes != 0:
  128. raise ValueError(
  129. f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
  130. f"needs to be a round multiple of the number of processes ({num_processes})."
  131. )
  132. self.batch_sampler = batch_sampler
  133. self.num_processes = num_processes
  134. self.process_index = process_index
  135. self.split_batches = split_batches
  136. self.even_batches = even_batches
  137. self.batch_size = getattr(batch_sampler, "batch_size", None)
  138. self.drop_last = getattr(batch_sampler, "drop_last", False)
  139. if self.batch_size is None and self.even_batches:
  140. raise ValueError(
  141. "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
  142. "are not calling this method directly, set `accelerator.even_batches=False` instead."
  143. )
  144. @property
  145. def total_length(self):
  146. return len(self.batch_sampler)
  147. def __len__(self):
  148. if self.split_batches:
  149. # Split batches does not change the length of the batch sampler
  150. return len(self.batch_sampler)
  151. if len(self.batch_sampler) % self.num_processes == 0:
  152. # If the length is a round multiple of the number of processes, it's easy.
  153. return len(self.batch_sampler) // self.num_processes
  154. length = len(self.batch_sampler) // self.num_processes
  155. if self.drop_last:
  156. # Same if we drop the remainder.
  157. return length
  158. elif self.even_batches:
  159. # When we even batches we always get +1
  160. return length + 1
  161. else:
  162. # Otherwise it depends on the process index.
  163. return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
  164. def __iter__(self):
  165. return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
  166. def _iter_with_split(self):
  167. initial_data = []
  168. batch_length = self.batch_sampler.batch_size // self.num_processes
  169. for idx, batch in enumerate(self.batch_sampler):
  170. if idx == 0:
  171. initial_data = batch
  172. if len(batch) == self.batch_size:
  173. # If the batch is full, we yield the part of it this process is responsible of.
  174. yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
  175. # If drop_last is True of the last batch was full, iteration is over, otherwise...
  176. if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
  177. if not self.even_batches:
  178. if len(batch) > batch_length * self.process_index:
  179. yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
  180. else:
  181. # For degenerate cases where the dataset has less than num_process * batch_size samples
  182. while len(initial_data) < self.batch_size:
  183. initial_data += initial_data
  184. batch = batch + initial_data
  185. yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
  186. def _iter_with_no_split(self):
  187. initial_data = []
  188. batch_to_yield = []
  189. for idx, batch in enumerate(self.batch_sampler):
  190. # We gather the initial indices in case we need to circle back at the end.
  191. if not self.drop_last and idx < self.num_processes:
  192. initial_data += batch
  193. # We identify the batch to yield but wait until we ar sure every process gets a full batch before actually
  194. # yielding it.
  195. if idx % self.num_processes == self.process_index:
  196. batch_to_yield = batch
  197. if idx % self.num_processes == self.num_processes - 1 and (
  198. self.batch_size is None or len(batch) == self.batch_size
  199. ):
  200. yield batch_to_yield
  201. batch_to_yield = []
  202. # If drop_last is True, iteration is over, otherwise...
  203. if not self.drop_last and len(initial_data) > 0:
  204. if not self.even_batches:
  205. if len(batch_to_yield) > 0:
  206. yield batch_to_yield
  207. else:
  208. # ... we yield the complete batch we had saved before if it has the proper length
  209. if len(batch_to_yield) == self.batch_size:
  210. yield batch_to_yield
  211. # For degenerate cases where the dataset has less than num_process * batch_size samples
  212. while len(initial_data) < self.num_processes * self.batch_size:
  213. initial_data += initial_data
  214. # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
  215. if len(batch) == self.batch_size:
  216. batch = []
  217. idx += 1
  218. # Make sure we yield a multiple of self.num_processes batches
  219. cycle_index = 0
  220. while idx % self.num_processes != 0 or len(batch) > 0:
  221. end_index = cycle_index + self.batch_size - len(batch)
  222. batch += initial_data[cycle_index:end_index]
  223. if idx % self.num_processes == self.process_index:
  224. yield batch
  225. cycle_index = end_index
  226. batch = []
  227. idx += 1
  228. class IterableDatasetShard(IterableDataset):
  229. """
  230. Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
  231. always yield a number of samples that is a round multiple of the actual batch size (depending of the value of
  232. `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
  233. `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
  234. be too small or loop with indices from the beginning.
  235. Args:
  236. dataset (`torch.utils.data.dataset.IterableDataset`):
  237. The batch sampler to split in several shards.
  238. batch_size (`int`, *optional*, defaults to 1):
  239. The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
  240. `split_batches=True`).
  241. drop_last (`bool`, *optional*, defaults to `False`):
  242. Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
  243. beginning.
  244. num_processes (`int`, *optional*, defaults to 1):
  245. The number of processes running concurrently.
  246. process_index (`int`, *optional*, defaults to 0):
  247. The index of the current process.
  248. split_batches (`bool`, *optional*, defaults to `False`):
  249. Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
  250. yielding different full batches on each process.
  251. On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:
  252. - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
  253. argument is set to `False`.
  254. - the shard on process 0 to yield `[0, 1, 4, 5]` and the sampler on process 1 to yield `[2, 3, 6, 7]` if
  255. this argument is set to `True`.
  256. """
  257. def __init__(
  258. self,
  259. dataset: IterableDataset,
  260. batch_size: int = 1,
  261. drop_last: bool = False,
  262. num_processes: int = 1,
  263. process_index: int = 0,
  264. split_batches: bool = False,
  265. ):
  266. if split_batches and batch_size > 1 and batch_size % num_processes != 0:
  267. raise ValueError(
  268. f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
  269. f"needs to be a round multiple of the number of processes ({num_processes})."
  270. )
  271. self.dataset: IterableDataset = dataset
  272. self.batch_size = batch_size
  273. self.drop_last = drop_last
  274. self.num_processes = num_processes
  275. self.process_index = process_index
  276. self.split_batches = split_batches
  277. def set_epoch(self, epoch):
  278. self.epoch = epoch
  279. if hasattr(self.dataset, "set_epoch"):
  280. self.dataset.set_epoch(epoch)
  281. def __len__(self):
  282. # We will just raise the downstream error if the underlying dataset is not sized
  283. if self.drop_last:
  284. return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
  285. else:
  286. return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
  287. def __iter__(self):
  288. if (
  289. not hasattr(self.dataset, "set_epoch")
  290. and hasattr(self.dataset, "generator")
  291. and isinstance(self.dataset.generator, torch.Generator)
  292. ):
  293. self.dataset.generator.manual_seed(self.epoch)
  294. real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
  295. process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
  296. process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
  297. first_batch = None
  298. current_batch = []
  299. for element in self.dataset:
  300. current_batch.append(element)
  301. # Wait to have a full batch before yielding elements.
  302. if len(current_batch) == real_batch_size:
  303. for i in process_slice:
  304. yield current_batch[i]
  305. if first_batch is None:
  306. first_batch = current_batch.copy()
  307. current_batch = []
  308. # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
  309. if not self.drop_last and len(current_batch) > 0:
  310. if first_batch is None:
  311. first_batch = current_batch.copy()
  312. while len(current_batch) < real_batch_size:
  313. current_batch += first_batch
  314. for i in process_slice:
  315. yield current_batch[i]
  316. class DataLoaderStateMixin:
  317. """
  318. Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
  319. end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
  320. useful information that might be needed.
  321. **Available attributes:**
  322. - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
  323. - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
  324. batch size
  325. <Tip warning={true}>
  326. Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in
  327. `self.gradient_state`.
  328. </Tip>
  329. """
  330. def __init_subclass__(cls, **kwargs):
  331. cls.end_of_dataloader = False
  332. cls.remainder = -1
  333. def reset(self):
  334. self.end_of_dataloader = False
  335. self.remainder = -1
  336. def begin(self):
  337. "Prepares the gradient state for the current dataloader"
  338. self.reset()
  339. with suppress(Exception):
  340. if not self._drop_last:
  341. length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
  342. self.remainder = length % self.total_batch_size
  343. self.gradient_state._add_dataloader(self)
  344. def end(self):
  345. "Cleans up the gradient state after exiting the dataloader"
  346. self.gradient_state._remove_dataloader(self)
  347. class DataLoaderAdapter:
  348. """
  349. A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
  350. compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
  351. """
  352. def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
  353. self.use_stateful_dataloader = use_stateful_dataloader
  354. if is_torchdata_stateful_dataloader_available():
  355. from torchdata.stateful_dataloader import StatefulDataLoader
  356. if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
  357. raise ImportError(
  358. "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
  359. )
  360. if use_stateful_dataloader:
  361. torchdata_version = version.parse(importlib.metadata.version("torchdata"))
  362. if (
  363. "in_order" in kwargs
  364. and compare_versions(torchdata_version, "<", "0.11")
  365. and is_torch_version(">=", "2.6.0")
  366. ):
  367. kwargs.pop("in_order")
  368. self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
  369. else:
  370. self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
  371. if hasattr(self.base_dataloader, "state_dict"):
  372. self.dl_state_dict = self.base_dataloader.state_dict()
  373. def __getattr__(self, name):
  374. # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
  375. if name == "base_dataloader":
  376. raise AttributeError()
  377. # Delegate attribute access to the internal dataloader
  378. return getattr(self.base_dataloader, name)
  379. def state_dict(self):
  380. return self.dl_state_dict
  381. def load_state_dict(self, state_dict):
  382. self.base_dataloader.load_state_dict(state_dict)
  383. @property
  384. def __class__(self):
  385. """
  386. In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
  387. returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
  388. object.
  389. """
  390. return self.base_dataloader.__class__
  391. def __len__(self):
  392. return len(self.base_dataloader)
  393. def adjust_state_dict_for_prefetch(self):
  394. """
  395. Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
  396. `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
  397. overridden.
  398. This should modify `self.dl_state_dict` directly
  399. """
  400. # The state dict will be off by a factor of `n-1` batch too many during DDP,
  401. # so we need to adjust it here
  402. if PartialState().distributed_type != DistributedType.NO:
  403. factor = PartialState().num_processes - 1
  404. if self.dl_state_dict["_sampler_iter_yielded"] > 0:
  405. self.dl_state_dict["_sampler_iter_yielded"] -= factor
  406. if self.dl_state_dict["_num_yielded"] > 0:
  407. self.dl_state_dict["_num_yielded"] -= factor
  408. if self.dl_state_dict["_index_sampler_state"] is not None:
  409. if (
  410. "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
  411. and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
  412. ):
  413. self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
  414. def _update_state_dict(self):
  415. # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
  416. # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
  417. # what it wants to yield.
  418. #
  419. # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
  420. if hasattr(self.base_dataloader, "state_dict"):
  421. self.dl_state_dict = self.base_dataloader.state_dict()
  422. # Potentially modify the state_dict to adjust for prefetching
  423. self.adjust_state_dict_for_prefetch()
  424. # Then tag if we are at the end of the dataloader
  425. self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        device (`torch.device`, *optional*):
            If passed, the device to put all batches on.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: an optional `torch.Generator`
        synchronized_generator (`torch.Generator`, *optional*):
            A random number generator to keep synchronized across processes.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning. Skipped batches are still fetched (and moved to `device`),
            just not yielded.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
        _drop_last (`bool`, *optional*, defaults to `False`):
            Stored as `self._drop_last`.
        _non_blocking (`bool`, *optional*, defaults to `False`):
            Passed as `non_blocking` to `send_to_device` when moving batches to `device`.
        torch_device_mesh (*optional*):
            Accepted but not used by this class's own code.
        **kwargs (additional keyword arguments, *optional*):
            All other keyword arguments to pass to the regular `DataLoader` initialization.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
            Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
            number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        device=None,
        rng_types=None,
        synchronized_generator=None,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        torch_device_mesh=None,
        **kwargs,
    ):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.device = device
        self.rng_types = rng_types
        self.synchronized_generator = synchronized_generator
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        # Epoch counter; incremented only after a pass has been fully consumed.
        self.iteration = 0

    def __iter__(self):
        # Synchronize the requested RNGs before anything draws random numbers for this pass.
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()

        # Propagate the current epoch to samplers / dataset before creating the underlying iterator.
        self.set_epoch(self.iteration)
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            # Empty loader: nothing to yield, just unregister.
            self.end()
            return

        batch_index = 0
        while True:
            try:
                # But we still move it to the device so it is done before `StopIteration` is reached
                if self.device is not None:
                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
                # Snapshot the state before advancing the underlying iterator past `current_batch`.
                self._update_state_dict()
                next_batch = next(dataloader_iter)
                if batch_index >= self.skip_batches:
                    yield current_batch
                batch_index += 1
                current_batch = next_batch
            except StopIteration:
                # `current_batch` is the final batch: flag it before snapshotting and yielding it.
                self.end_of_dataloader = True
                self._update_state_dict()
                if batch_index >= self.skip_batches:
                    yield current_batch
                break

        # Only reached when the consumer exhausted the generator (not when it stopped early).
        self.iteration += 1
        self.end()

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderShard, *args[1:])

    def set_epoch(self, epoch: int):
        """Set the epoch on this loader and forward it to whichever sampler(s) or dataset expose `set_epoch`."""
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        if (
            hasattr(self.batch_sampler, "batch_sampler")
            and hasattr(self.batch_sampler.batch_sampler, "sampler")
            and hasattr(self.batch_sampler.batch_sampler.sampler, "set_epoch")
        ):
            self.batch_sampler.batch_sampler.sampler.set_epoch(epoch)
        # We support if a custom `Dataset` implementation has `set_epoch`
        # or in general HF datasets `Datasets`
        # NOTE(review): this `elif` pairs only with the nested `batch_sampler.batch_sampler.sampler` check above, so
        # `dataset.set_epoch` is skipped whenever that nested sampler has `set_epoch` — confirm this is intended.
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    @property
    def total_batch_size(self):
        # When `self.sampler` is itself a `BatchSampler`, it carries the batch size; otherwise use `batch_sampler`.
        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
        return (
            batch_sampler.batch_size
            if getattr(batch_sampler, "split_batches", False)
            else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
        )

    @property
    def total_dataset_length(self):
        # Prefer a dataset-provided `total_length`, falling back to `len()`.
        if hasattr(self.dataset, "total_length"):
            return self.dataset.total_length
        else:
            return len(self.dataset)

    def get_sampler(self):
        # Delegates to the module-level `get_sampler` helper (defined elsewhere in this module).
        return get_sampler(self)

    def set_sampler(self, sampler):
        """Replace the index sampler used by this loader's batch sampler (and by a nested batch sampler, if any)."""
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler
  559. if is_torch_xla_available():
  560. import torch_xla.distributed.parallel_loader as xpl
  561. class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
  562. """
  563. Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
  564. XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
  565. prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
  566. thread only.
  567. **Available attributes:**
  568. - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
  569. Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
  570. number of processes
  571. - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
  572. """
  573. def __init__(self, dataloader: DataLoaderShard, device: torch.device):
  574. super().__init__(dataloader, device)
  575. self._rng_types = self._loader.rng_types
  576. self._loader.rng_types = None
  577. self.device = device
  578. def __iter__(self):
  579. if self._rng_types is not None:
  580. synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)
  581. return super().__iter__()
  582. def set_epoch(self, epoch: int):
  583. if hasattr(self.dataloader, "set_epoch"):
  584. self.dataloader.set_epoch(epoch)
  585. @property
  586. def total_batch_size(self):
  587. return self._loader.total_batch_size
  588. @property
  589. def total_dataset_length(self):
  590. return self._loader.total_dataset_length
  591. @property
  592. def batch_sampler(self):
  593. return self._loader.batch_sampler
  594. @property
  595. def dataloader(self):
  596. return self._loader
  597. class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
  598. """
  599. Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
  600. their part of the batch.
  601. Args:
  602. split_batches (`bool`, *optional*, defaults to `False`):
  603. Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
  604. yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
  605. `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
  606. the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
  607. `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
  608. size of the `dataloader` is a round multiple of `batch_size`.
  609. skip_batches (`int`, *optional*, defaults to 0):
  610. The number of batches to skip at the beginning of an iteration.
  611. use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
  612. Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
  613. **Available attributes:**
  614. - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
  615. Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
  616. number of processes
  617. - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
  618. """
  619. def __init__(
  620. self,
  621. dataset,
  622. split_batches: bool = False,
  623. skip_batches=0,
  624. use_stateful_dataloader=False,
  625. _drop_last: bool = False,
  626. _non_blocking: bool = False,
  627. slice_fn=None,
  628. torch_device_mesh=None,
  629. **kwargs,
  630. ):
  631. shuffle = False
  632. from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
  633. # We need to save the shuffling state of the DataPipe
  634. if isinstance(dataset, ShufflerIterDataPipe):
  635. shuffle = dataset._shuffle_enabled
  636. super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
  637. self.split_batches = split_batches
  638. if shuffle:
  639. torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
  640. self.gradient_state = GradientState()
  641. self.state = PartialState()
  642. self._drop_last = _drop_last
  643. self._non_blocking = _non_blocking
  644. self.skip_batches = skip_batches
  645. self.torch_device_mesh = torch_device_mesh
  646. self.slice_fn = slice_tensors if slice_fn is None else slice_fn
  647. self.iteration = 0
  648. # if a device mesh is provided extract each dimension (dp, fsdp, tp)
  649. # device mesh may hold any number of dimensions, however,
  650. # below code is for targeted support for dp, fsdp and tp
  651. # device mesh will be used only if there is tp involved
  652. # or any multi-dimensional parallelism involving tp
  653. # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
  654. # otherwise the default behaviour not using device mesh should be sufficient
  655. # since multi dimensional parallelism devoid of tp would anyway need
  656. # different batches for each process irrespective of dp or fsdp
  657. self.submesh_tp = None
  658. self.submesh_dp = None
  659. self.submesh_fsdp = None
  660. if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
  661. self.submesh_tp = self.torch_device_mesh["tp"]
  662. if "dp" in self.torch_device_mesh.mesh_dim_names:
  663. self.submesh_dp = self.torch_device_mesh["dp"]
  664. if "fsdp" in self.torch_device_mesh.mesh_dim_names:
  665. self.submesh_fsdp = self.torch_device_mesh["fsdp"]
  666. if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
  667. raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
  668. def _fetch_batches(self, iterator):
  669. batches, batch = None, None
  670. # On process 0, we gather the batch to dispatch.
  671. if self.state.process_index == 0:
  672. # Procedure to support TP only is simpler
  673. # since we want to dispatch the same batch of samples across all ranks
  674. # this removes complexity of handling multiple tp rank groups when TP + DP
  675. # combination is involved.
  676. try:
  677. # for TP case avoid using split_batches
  678. # since it would mean that the dataloader should be spilling out
  679. # duplicates of batches.
  680. if self.split_batches:
  681. # One batch of the main iterator is dispatched and split.
  682. if self.submesh_tp:
  683. logger.warning(
  684. "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
  685. "otherwise, use dispatch_batches=True instead."
  686. )
  687. self._update_state_dict()
  688. batch = next(iterator)
  689. else:
  690. # num_processes batches of the main iterator are concatenated then dispatched and split.
  691. # We add the batches one by one so we have the remainder available when drop_last=False.
  692. batches = []
  693. if self.submesh_tp:
  694. # when tp, extract single batch and then replicate
  695. self._update_state_dict()
  696. batch = next(iterator)
  697. batches = [batch] * self.state.num_processes
  698. else:
  699. for _ in range(self.state.num_processes):
  700. self._update_state_dict()
  701. batches.append(next(iterator))
  702. try:
  703. batch = concatenate(batches, dim=0)
  704. except RuntimeError as e:
  705. raise RuntimeError(
  706. "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
  707. "either pass `dispatch_batches=False` and have each process fetch its own batch "
  708. " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
  709. "slice it into `num_processes` batches for each process."
  710. ) from e
  711. # In both cases, we need to get the structure of the batch that we will broadcast on other
  712. # processes to initialize the tensors with the right shape.
  713. # data_structure, stop_iteration
  714. batch_info = [get_data_structure(batch), False]
  715. except StopIteration:
  716. batch_info = [None, True]
  717. else:
  718. batch_info = [None, self._stop_iteration]
  719. # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
  720. broadcast_object_list(batch_info)
  721. self._stop_iteration = batch_info[1]
  722. if self._stop_iteration:
  723. # If drop_last is False and split_batches is False, we may have a remainder to take care of.
  724. if not self.split_batches and not self._drop_last:
  725. if self.state.process_index == 0 and len(batches) > 0:
  726. batch = concatenate(batches, dim=0)
  727. batch_info = [get_data_structure(batch), False]
  728. else:
  729. batch_info = [None, True]
  730. broadcast_object_list(batch_info)
  731. return batch, batch_info
  732. def __iter__(self):
  733. self.begin()
  734. self.set_epoch(self.iteration)
  735. main_iterator = None
  736. if is_torch_version(">=", "2.0.1"):
  737. # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
  738. # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
  739. # But, we only iterate through the DataLoader on process 0.
  740. main_iterator = self.base_dataloader.__iter__()
  741. elif self.state.process_index == 0:
  742. main_iterator = self.base_dataloader.__iter__()
  743. stop_iteration = False
  744. self._stop_iteration = False
  745. first_batch = None
  746. next_batch, next_batch_info = self._fetch_batches(main_iterator)
  747. batch_index = 0
  748. while not stop_iteration:
  749. batch, batch_info = next_batch, next_batch_info
  750. if self.state.process_index != 0:
  751. # Initialize tensors on other processes than process 0.
  752. batch = initialize_tensors(batch_info[0])
  753. batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
  754. # Broadcast the batch before splitting it.
  755. batch = broadcast(batch, from_process=0)
  756. if not self._drop_last and first_batch is None:
  757. # We keep at least num processes elements of the first batch to be able to complete the last batch
  758. first_batch = self.slice_fn(
  759. batch,
  760. slice(0, self.state.num_processes),
  761. process_index=self.state.process_index,
  762. num_processes=self.state.num_processes,
  763. )
  764. if batch is None:
  765. raise ValueError(
  766. f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
  767. )
  768. observed_batch_size = find_batch_size(batch)
  769. batch_size = observed_batch_size // self.state.num_processes
  770. stop_iteration = self._stop_iteration
  771. if not stop_iteration:
  772. # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
  773. # the dataloader since the number of batches is a round multiple of the number of processes.
  774. next_batch, next_batch_info = self._fetch_batches(main_iterator)
  775. # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
  776. if self._stop_iteration and next_batch_info[0] is None:
  777. stop_iteration = True
  778. if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
  779. # If the last batch is not complete, let's add the first batch to it.
  780. batch = concatenate([batch, first_batch], dim=0)
  781. # Batch size computation above is wrong, it's off by 1 so we fix it.
  782. batch_size += 1
  783. data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
  784. batch = self.slice_fn(
  785. batch,
  786. data_slice,
  787. process_index=self.state.process_index,
  788. num_processes=self.state.num_processes,
  789. )
  790. if stop_iteration:
  791. self.end_of_dataloader = True
  792. self._update_state_dict()
  793. self.remainder = observed_batch_size
  794. if batch_index >= self.skip_batches:
  795. yield batch
  796. batch_index += 1
  797. self.iteration += 1
  798. self.end()
  799. def set_epoch(self, epoch: int):
  800. # In case it is manually passed in, the user can set it to what they like
  801. if self.iteration != epoch:
  802. self.iteration = epoch
  803. if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
  804. self.batch_sampler.sampler.set_epoch(epoch)
  805. elif hasattr(self.dataset, "set_epoch"):
  806. self.dataset.set_epoch(epoch)
  807. def __len__(self):
  808. whole_length = len(self.base_dataloader)
  809. if self.split_batches:
  810. return whole_length
  811. elif self._drop_last:
  812. return whole_length // self.state.num_processes
  813. else:
  814. return math.ceil(whole_length / self.state.num_processes)
  815. def __reduce__(self):
  816. """
  817. Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
  818. be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
  819. `__class__` member.
  820. """
  821. args = super().__reduce__()
  822. return (DataLoaderDispatcher, *args[1:])
  823. @property
  824. def total_batch_size(self):
  825. return (
  826. self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
  827. )
  828. @property
  829. def total_dataset_length(self):
  830. return len(self.dataset)
  831. def get_sampler(self):
  832. return get_sampler(self)
  833. def set_sampler(self, sampler):
  834. sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
  835. if sampler_is_batch_sampler:
  836. self.sampler.sampler = sampler
  837. else:
  838. self.batch_sampler.sampler = sampler
  839. if hasattr(self.batch_sampler, "batch_sampler"):
  840. self.batch_sampler.batch_sampler.sampler = sampler
  841. def get_sampler(dataloader):
  842. """
  843. Get the sampler associated to the dataloader
  844. Args:
  845. dataloader (`torch.utils.data.dataloader.DataLoader`):
  846. The data loader to split across several devices.
  847. Returns:
  848. `torch.utils.data.Sampler`: The sampler associated to the dataloader
  849. """
  850. sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
  851. if sampler_is_batch_sampler:
  852. sampler = getattr(dataloader.sampler, "sampler", None)
  853. else:
  854. sampler = getattr(dataloader.batch_sampler, "sampler", None)
  855. return sampler
  856. def prepare_data_loader(
  857. dataloader: DataLoader,
  858. device: Optional[torch.device] = None,
  859. num_processes: Optional[int] = None,
  860. process_index: Optional[int] = None,
  861. split_batches: bool = False,
  862. put_on_device: bool = False,
  863. rng_types: Optional[list[Union[str, RNGType]]] = None,
  864. dispatch_batches: Optional[bool] = None,
  865. even_batches: bool = True,
  866. slice_fn_for_dispatch: Optional[Callable] = None,
  867. use_seedable_sampler: bool = False,
  868. data_seed: Optional[int] = None,
  869. non_blocking: bool = False,
  870. use_stateful_dataloader: bool = False,
  871. torch_device_mesh=None,
  872. ) -> DataLoader:
  873. """
  874. Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
  875. Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
  876. at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
  877. Args:
  878. dataloader (`torch.utils.data.dataloader.DataLoader`):
  879. The data loader to split across several devices.
  880. device (`torch.device`):
  881. The target device for the returned `DataLoader`.
  882. num_processes (`int`, *optional*):
  883. The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
  884. process_index (`int`, *optional*):
  885. The index of the current process. Will default to the value given by [`~state.PartialState`].
  886. split_batches (`bool`, *optional*, defaults to `False`):
  887. Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
  888. yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
  889. `num_processes` batches at each iteration).
  890. Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
  891. this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
  892. otherwise.
  893. Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
  894. `batch_size`.
  895. put_on_device (`bool`, *optional*, defaults to `False`):
  896. Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
  897. dictionaries of tensors).
  898. rng_types (list of `str` or [`~utils.RNGType`]):
  899. The list of random number generators to synchronize at the beginning of each iteration. Should be one or
  900. several of:
  901. - `"torch"`: the base torch random number generator
  902. - `"cuda"`: the CUDA random number generator (GPU only)
  903. - `"xla"`: the XLA random number generator (TPU only)
  904. - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
  905. dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
  906. dispatch_batches (`bool`, *optional*):
  907. If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
  908. are split and broadcast to each process. Will default to `True` when the underlying dataset is an
  909. `IterableDataset`, `False` otherwise.
  910. even_batches (`bool`, *optional*, defaults to `True`):
  911. If set to `True`, in cases where the total batch size across all processes does not exactly divide the
  912. dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
  913. all workers.
  914. slice_fn_for_dispatch (`Callable`, *optional*`):
  915. If passed, this function will be used to slice tensors across `num_processes`. Will default to
  916. [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
  917. ignored otherwise.
  918. use_seedable_sampler (`bool`, *optional*, defaults to `False`):
  919. Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
  920. reproducibility. Comes at a cost of potentially different performances due to different shuffling
  921. algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
  922. `self.set_epoch`
  923. data_seed (`int`, *optional*, defaults to `None`):
  924. The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
  925. will use the current default seed from torch.
  926. non_blocking (`bool`, *optional*, defaults to `False`):
  927. If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
  928. `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
  929. use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
  930. "If set to true, the dataloader prepared by the Accelerator will be backed by "
  931. "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
  932. This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
  933. torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
  934. PyTorch device mesh.
  935. Returns:
  936. `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
  937. <Tip warning={true}>
  938. `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
  939. equal to `False`
  940. </Tip>
  941. """
  942. if dispatch_batches is None:
  943. if not put_on_device:
  944. dispatch_batches = False
  945. else:
  946. dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
  947. if dispatch_batches and not put_on_device:
  948. raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
  949. # Grab defaults from PartialState
  950. state = PartialState()
  951. if num_processes is None:
  952. num_processes = state.num_processes
  953. if process_index is None:
  954. process_index = state.process_index
  955. if torch_device_mesh:
  956. if state.distributed_type == DistributedType.DEEPSPEED:
  957. # In DeepSpeed, the optimizer sharing level in DP is determined by the config file.
  958. # Only considers "dp" and "tp".
  959. # Given a device mesh (dp, tp) = (2, 3):
  960. # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
  961. # - Processes with the same DP rank will receive the same batch.
  962. submesh_tp_size = 1
  963. if "tp" in torch_device_mesh.mesh_dim_names:
  964. submesh_tp_size = torch_device_mesh["tp"].size()
  965. process_index = process_index // submesh_tp_size
  966. num_processes = num_processes // submesh_tp_size
  967. else:
  968. # when device mesh is used, specifically with TP
  969. # then there is need to update process_index and num_processes
  970. # to bring in the effect of generating same batch across TP ranks
  971. # and different batch across FSDP and DP ranks.
  972. # Example:
  973. # if device mesh is (dp,fsdp,tp) = (2, 2, 3)
  974. # ranks would range from 0...11
  975. # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
  976. # processes with same ranks/ids would receive the same batch
  977. # for CP the same as TP applies
  978. submesh_fsdp_size = 1
  979. submesh_dp_size = 1
  980. submesh_tp_size = 1
  981. submesh_cp_size = 1
  982. if "tp" in torch_device_mesh.mesh_dim_names:
  983. submesh_tp_size = torch_device_mesh["tp"].size()
  984. if "cp" in torch_device_mesh.mesh_dim_names:
  985. submesh_cp_size = torch_device_mesh["cp"].size()
  986. if "dp_replicate" in torch_device_mesh.mesh_dim_names:
  987. submesh_dp_size = torch_device_mesh["dp_replicate"].size()
  988. if "dp_shard" in torch_device_mesh.mesh_dim_names:
  989. submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
  990. process_index = process_index // (submesh_tp_size * submesh_cp_size)
  991. num_processes = submesh_fsdp_size * submesh_dp_size
  992. # Sanity check
  993. if split_batches:
  994. if dataloader.batch_size is not None:
  995. batch_size_for_check = dataloader.batch_size
  996. else:
  997. # For custom batch_sampler
  998. if hasattr(dataloader.batch_sampler, "batch_size"):
  999. batch_size_for_check = dataloader.batch_sampler.batch_size
  1000. else:
  1001. raise ValueError(
  1002. "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
  1003. "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
  1004. "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
  1005. f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
  1006. )
  1007. if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
  1008. raise ValueError(
  1009. f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
  1010. f"needs to be a round multiple of the number of processes ({num_processes})."
  1011. )
  1012. new_dataset = dataloader.dataset
  1013. # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
  1014. new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
  1015. sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
  1016. synchronized_generator = None
  1017. sampler = get_sampler(dataloader)
  1018. if isinstance(sampler, RandomSampler) and use_seedable_sampler:
  1019. # When iterating through the dataloader during distributed processes
  1020. # we want to ensure that on each process we are iterating through the same
  1021. # samples in the same order if a seed is set. This requires a tweak
  1022. # to the `torch.utils.data.RandomSampler` class (if used).
  1023. sampler = SeedableRandomSampler(
  1024. data_source=sampler.data_source,
  1025. replacement=sampler.replacement,
  1026. num_samples=sampler._num_samples,
  1027. generator=getattr(
  1028. sampler,
  1029. "generator",
  1030. torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
  1031. ),
  1032. data_seed=data_seed,
  1033. )
  1034. if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
  1035. # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
  1036. generator = torch.Generator(
  1037. device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
  1038. )
  1039. seed = int(torch.empty((), dtype=torch.int64).random_().item())
  1040. generator.manual_seed(seed)
  1041. dataloader.generator = generator
  1042. dataloader.sampler.generator = generator
  1043. # No change if no multiprocess
  1044. if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
  1045. if is_datasets_available():
  1046. from datasets import IterableDataset as DatasetsIterableDataset
  1047. if (
  1048. is_datasets_available()
  1049. and isinstance(new_dataset, DatasetsIterableDataset)
  1050. and not split_batches
  1051. and new_dataset.n_shards > num_processes
  1052. ):
  1053. new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
  1054. elif isinstance(new_dataset, IterableDataset):
  1055. if getattr(dataloader.dataset, "generator", None) is not None:
  1056. synchronized_generator = dataloader.dataset.generator
  1057. new_dataset = IterableDatasetShard(
  1058. new_dataset,
  1059. batch_size=dataloader.batch_size,
  1060. drop_last=dataloader.drop_last,
  1061. num_processes=num_processes,
  1062. process_index=process_index,
  1063. split_batches=split_batches,
  1064. )
  1065. else:
  1066. if not use_seedable_sampler and hasattr(sampler, "generator"):
  1067. if sampler.generator is None:
  1068. sampler.generator = torch.Generator(
  1069. device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
  1070. )
  1071. seed = int(torch.empty((), dtype=torch.int64).random_().item())
  1072. sampler.generator.manual_seed(seed)
  1073. synchronized_generator = sampler.generator
  1074. batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
  1075. new_batch_sampler = BatchSamplerShard(
  1076. batch_sampler,
  1077. num_processes=num_processes,
  1078. process_index=process_index,
  1079. split_batches=split_batches,
  1080. even_batches=even_batches,
  1081. )
  1082. # We ignore all of those since they are all dealt with by our new_batch_sampler
  1083. ignore_kwargs = [
  1084. "batch_size",
  1085. "shuffle",
  1086. "sampler",
  1087. "batch_sampler",
  1088. "drop_last",
  1089. ]
  1090. if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
  1091. rng_types.remove("generator")
  1092. kwargs = {
  1093. k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
  1094. for k in _PYTORCH_DATALOADER_KWARGS
  1095. if k not in ignore_kwargs
  1096. }
  1097. # Need to provide batch_size as batch_sampler is None for Iterable dataset
  1098. if new_batch_sampler is None:
  1099. kwargs["drop_last"] = dataloader.drop_last
  1100. kwargs["batch_size"] = (
  1101. dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
  1102. )
  1103. if dispatch_batches:
  1104. kwargs.pop("generator")
  1105. dataloader = DataLoaderDispatcher(
  1106. new_dataset,
  1107. split_batches=split_batches,
  1108. batch_sampler=new_batch_sampler,
  1109. _drop_last=dataloader.drop_last,
  1110. _non_blocking=non_blocking,
  1111. slice_fn=slice_fn_for_dispatch,
  1112. use_stateful_dataloader=use_stateful_dataloader,
  1113. torch_device_mesh=torch_device_mesh,
  1114. **kwargs,
  1115. )
  1116. elif sampler_is_batch_sampler:
  1117. dataloader = DataLoaderShard(
  1118. new_dataset,
  1119. device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
  1120. sampler=new_batch_sampler,
  1121. batch_size=dataloader.batch_size,
  1122. rng_types=rng_types,
  1123. _drop_last=dataloader.drop_last,
  1124. _non_blocking=non_blocking,
  1125. synchronized_generator=synchronized_generator,
  1126. use_stateful_dataloader=use_stateful_dataloader,
  1127. **kwargs,
  1128. )
  1129. else:
  1130. dataloader = DataLoaderShard(
  1131. new_dataset,
  1132. device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
  1133. batch_sampler=new_batch_sampler,
  1134. rng_types=rng_types,
  1135. synchronized_generator=synchronized_generator,
  1136. _drop_last=dataloader.drop_last,
  1137. _non_blocking=non_blocking,
  1138. use_stateful_dataloader=use_stateful_dataloader,
  1139. **kwargs,
  1140. )
  1141. if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
  1142. dataloader.set_sampler(sampler)
  1143. if state.distributed_type == DistributedType.XLA:
  1144. return MpDeviceLoaderWrapper(dataloader, device)
  1145. return dataloader
  1146. class SkipBatchSampler(BatchSampler):
  1147. """
  1148. A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
  1149. Should not be used if the original dataloader is a `StatefulDataLoader`.
  1150. """
  1151. def __init__(self, batch_sampler, skip_batches=0):
  1152. self.batch_sampler = batch_sampler
  1153. self.skip_batches = skip_batches
  1154. def __iter__(self):
  1155. for index, samples in enumerate(self.batch_sampler):
  1156. if index >= self.skip_batches:
  1157. yield samples
  1158. @property
  1159. def total_length(self):
  1160. return len(self.batch_sampler)
  1161. def __len__(self):
  1162. return len(self.batch_sampler) - self.skip_batches
  1163. class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
  1164. """
  1165. Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
  1166. `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
  1167. Args:
  1168. dataset (`torch.utils.data.dataset.Dataset`):
  1169. The dataset to use to build this dataloader.
  1170. skip_batches (`int`, *optional*, defaults to 0):
  1171. The number of batches to skip at the beginning.
  1172. kwargs:
  1173. All other keyword arguments to pass to the regular `DataLoader` initialization.
  1174. """
  1175. def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
  1176. super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
  1177. self.skip_batches = skip_batches
  1178. self.gradient_state = GradientState()
  1179. def __iter__(self):
  1180. self.begin()
  1181. for index, batch in enumerate(self.base_dataloader.__iter__()):
  1182. if index >= self.skip_batches:
  1183. self._update_state_dict()
  1184. yield batch
  1185. self.end()
  1186. def __len__(self):
  1187. return len(self.base_dataloader) - self.skip_batches
  1188. def __reduce__(self):
  1189. """
  1190. Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
  1191. explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
  1192. `__class__` member.
  1193. """
  1194. args = super().__reduce__()
  1195. return (SkipDataLoader, *args[1:])
  1196. def skip_first_batches(dataloader, num_batches=0):
  1197. """
  1198. Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
  1199. the original dataloader is a `StatefulDataLoader`.
  1200. """
  1201. state = PartialState()
  1202. if state.distributed_type == DistributedType.XLA:
  1203. device = dataloader.device
  1204. dataloader = dataloader.dataloader
  1205. dataset = dataloader.dataset
  1206. sampler_is_batch_sampler = False
  1207. if isinstance(dataset, IterableDataset):
  1208. new_batch_sampler = None
  1209. else:
  1210. sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
  1211. batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
  1212. new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
  1213. # We ignore all of those since they are all dealt with by our new_batch_sampler
  1214. ignore_kwargs = [
  1215. "batch_size",
  1216. "shuffle",
  1217. "sampler",
  1218. "batch_sampler",
  1219. "drop_last",
  1220. ]
  1221. kwargs = {
  1222. k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
  1223. for k in _PYTORCH_DATALOADER_KWARGS
  1224. if k not in ignore_kwargs
  1225. }
  1226. # Need to provide batch_size as batch_sampler is None for Iterable dataset
  1227. if new_batch_sampler is None:
  1228. kwargs["drop_last"] = dataloader.drop_last
  1229. kwargs["batch_size"] = dataloader.batch_size
  1230. if isinstance(dataloader, DataLoaderDispatcher):
  1231. if new_batch_sampler is None:
  1232. # Need to manually skip batches in the dataloader
  1233. kwargs["skip_batches"] = num_batches
  1234. dataloader = DataLoaderDispatcher(
  1235. dataset,
  1236. split_batches=dataloader.split_batches,
  1237. batch_sampler=new_batch_sampler,
  1238. _drop_last=dataloader._drop_last,
  1239. **kwargs,
  1240. )
  1241. elif isinstance(dataloader, DataLoaderShard):
  1242. if new_batch_sampler is None:
  1243. # Need to manually skip batches in the dataloader
  1244. kwargs["skip_batches"] = num_batches
  1245. elif sampler_is_batch_sampler:
  1246. kwargs["sampler"] = new_batch_sampler
  1247. kwargs["batch_size"] = dataloader.batch_size
  1248. else:
  1249. kwargs["batch_sampler"] = new_batch_sampler
  1250. dataloader = DataLoaderShard(
  1251. dataset,
  1252. device=dataloader.device,
  1253. rng_types=dataloader.rng_types,
  1254. synchronized_generator=dataloader.synchronized_generator,
  1255. **kwargs,
  1256. )
  1257. else:
  1258. if new_batch_sampler is None:
  1259. # Need to manually skip batches in the dataloader
  1260. dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
  1261. else:
  1262. dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
  1263. if state.distributed_type == DistributedType.XLA:
  1264. dataloader = MpDeviceLoaderWrapper(dataloader, device)
  1265. return dataloader