operations.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. A set of basic tensor ops compatible with tpu, gpu, and multigpu
  16. """
  17. import pickle
  18. import warnings
  19. from collections.abc import Mapping
  20. from contextlib import contextmanager, nullcontext
  21. from functools import update_wrapper, wraps
  22. from typing import Any
  23. import torch
  24. from ..state import AcceleratorState, PartialState
  25. from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
  26. from .dataclasses import DistributedType, TensorInformation
  27. from .imports import (
  28. is_npu_available,
  29. is_torch_distributed_available,
  30. is_torch_xla_available,
  31. )
  32. from .versions import is_torch_version
  33. if is_torch_xla_available():
  34. import torch_xla.core.xla_model as xm
  35. if is_torch_distributed_available():
  36. from torch.distributed import ReduceOp
  37. def is_torch_tensor(tensor):
  38. return isinstance(tensor, torch.Tensor)
  39. def is_torch_xpu_tensor(tensor):
  40. return isinstance(
  41. tensor,
  42. torch.xpu.FloatTensor,
  43. torch.xpu.ByteTensor,
  44. torch.xpu.IntTensor,
  45. torch.xpu.LongTensor,
  46. torch.xpu.HalfTensor,
  47. torch.xpu.DoubleTensor,
  48. torch.xpu.BFloat16Tensor,
  49. )
  50. def is_tensor_information(tensor_info):
  51. return isinstance(tensor_info, TensorInformation)
  52. def is_namedtuple(data):
  53. """
  54. Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
  55. `namedtuple` perfectly.
  56. """
  57. return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields")
  58. def honor_type(obj, generator):
  59. """
  60. Cast a generator to the same type as obj (list, tuple, or namedtuple)
  61. """
  62. # Some objects may not be able to instantiate from a generator directly
  63. if is_namedtuple(obj):
  64. return type(obj)(*list(generator))
  65. else:
  66. return type(obj)(generator)
  67. def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
  68. """
  69. Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
  70. Args:
  71. func (`callable`):
  72. The function to recursively apply.
  73. data (nested list/tuple/dictionary of `main_type`):
  74. The data on which to apply `func`
  75. *args:
  76. Positional arguments that will be passed to `func` when applied on the unpacked data.
  77. main_type (`type`, *optional*, defaults to `torch.Tensor`):
  78. The base type of the objects to which apply `func`.
  79. error_on_other_type (`bool`, *optional*, defaults to `False`):
  80. Whether to return an error or not if after unpacking `data`, we get on an object that is not of type
  81. `main_type`. If `False`, the function will leave objects of types different than `main_type` unchanged.
  82. **kwargs (additional keyword arguments, *optional*):
  83. Keyword arguments that will be passed to `func` when applied on the unpacked data.
  84. Returns:
  85. The same data structure as `data` with `func` applied to every object of type `main_type`.
  86. """
  87. if isinstance(data, (tuple, list)):
  88. return honor_type(
  89. data,
  90. (
  91. recursively_apply(
  92. func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
  93. )
  94. for o in data
  95. ),
  96. )
  97. elif isinstance(data, Mapping):
  98. return type(data)(
  99. {
  100. k: recursively_apply(
  101. func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
  102. )
  103. for k, v in data.items()
  104. }
  105. )
  106. elif test_type(data):
  107. return func(data, *args, **kwargs)
  108. elif error_on_other_type:
  109. raise TypeError(
  110. f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
  111. f"objects that are valid for `{test_type.__name__}` should be passed."
  112. )
  113. return data
  114. def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
  115. """
  116. Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
  117. Args:
  118. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  119. The data to send to a given device.
  120. device (`torch.device`):
  121. The device to send the data to.
  122. Returns:
  123. The same data structure as `tensor` with all tensors sent to the proper device.
  124. """
  125. if is_torch_tensor(tensor) or hasattr(tensor, "to"):
  126. # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
  127. if device == "npu":
  128. device = "npu:0"
  129. try:
  130. return tensor.to(device, non_blocking=non_blocking)
  131. except TypeError: # .to() doesn't accept non_blocking as kwarg
  132. return tensor.to(device)
  133. except AssertionError as error:
  134. # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
  135. # This call is inside the try-block since is_npu_available is not supported by torch.compile.
  136. if is_npu_available():
  137. if isinstance(device, int):
  138. device = f"npu:{device}"
  139. else:
  140. raise error
  141. try:
  142. return tensor.to(device, non_blocking=non_blocking)
  143. except TypeError: # .to() doesn't accept non_blocking as kwarg
  144. return tensor.to(device)
  145. elif isinstance(tensor, (tuple, list)):
  146. return honor_type(
  147. tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
  148. )
  149. elif isinstance(tensor, Mapping):
  150. if isinstance(skip_keys, str):
  151. skip_keys = [skip_keys]
  152. elif skip_keys is None:
  153. skip_keys = []
  154. return type(tensor)(
  155. {
  156. k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
  157. for k, t in tensor.items()
  158. }
  159. )
  160. else:
  161. return tensor
  162. def get_data_structure(data):
  163. """
  164. Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors.
  165. Args:
  166. data (nested list/tuple/dictionary of `torch.Tensor`):
  167. The data to send to analyze.
  168. Returns:
  169. The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors.
  170. """
  171. def _get_data_structure(tensor):
  172. return TensorInformation(shape=tensor.shape, dtype=tensor.dtype)
  173. return recursively_apply(_get_data_structure, data)
  174. def get_shape(data):
  175. """
  176. Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list.
  177. Args:
  178. data (nested list/tuple/dictionary of `torch.Tensor`):
  179. The data to send to analyze.
  180. Returns:
  181. The same data structure as `data` with lists of tensor shapes instead of tensors.
  182. """
  183. def _get_shape(tensor):
  184. return list(tensor.shape)
  185. return recursively_apply(_get_shape, data)
  186. def initialize_tensors(data_structure):
  187. """
  188. Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].
  189. Returns:
  190. The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].
  191. """
  192. def _initialize_tensor(tensor_info):
  193. return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)
  194. return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
  195. def find_batch_size(data):
  196. """
  197. Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.
  198. Args:
  199. data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
  200. Returns:
  201. `int`: The batch size.
  202. """
  203. if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0):
  204. raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
  205. if isinstance(data, (tuple, list)):
  206. return find_batch_size(data[0])
  207. elif isinstance(data, Mapping):
  208. for k in data.keys():
  209. return find_batch_size(data[k])
  210. elif not isinstance(data, torch.Tensor):
  211. raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
  212. return data.shape[0]
  213. def ignorant_find_batch_size(data):
  214. """
  215. Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised
  216. Args:
  217. data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
  218. Returns:
  219. `int`: The batch size.
  220. """
  221. try:
  222. return find_batch_size(data)
  223. except (ValueError, TypeError):
  224. pass
  225. return None
  226. def listify(data):
  227. """
  228. Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.
  229. Args:
  230. data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers.
  231. Returns:
  232. The same data structure as `data` with lists of numbers instead of `torch.Tensor`.
  233. """
  234. def _convert_to_list(tensor):
  235. tensor = tensor.detach().cpu()
  236. if tensor.dtype == torch.bfloat16:
  237. # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
  238. # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
  239. # Until Numpy adds bfloat16, we must convert float32.
  240. tensor = tensor.to(torch.float32)
  241. return tensor.tolist()
  242. return recursively_apply(_convert_to_list, data)
  243. def _tpu_gather(tensor):
  244. def _tpu_gather_one(tensor):
  245. if tensor.ndim == 0:
  246. tensor = tensor.clone()[None]
  247. # Can only gather contiguous tensors
  248. if not tensor.is_contiguous():
  249. tensor = tensor.contiguous()
  250. return xm.all_gather(tensor)
  251. res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)
  252. xm.mark_step()
  253. return res
  254. def _gpu_gather(tensor):
  255. state = PartialState()
  256. gather_op = torch.distributed.all_gather_into_tensor
  257. # NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0
  258. if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
  259. torch.xpu.synchronize()
  260. def _gpu_gather_one(tensor):
  261. if tensor.ndim == 0:
  262. tensor = tensor.clone()[None]
  263. # Can only gather contiguous tensors
  264. if not tensor.is_contiguous():
  265. tensor = tensor.contiguous()
  266. if state.backend is not None and state.backend != "gloo":
  267. # We use `empty` as `all_gather_into_tensor` slightly
  268. # differs from `all_gather` for better efficiency,
  269. # and we rely on the number of items in the tensor
  270. # rather than its direct shape
  271. output_tensors = torch.empty(
  272. state.num_processes * tensor.numel(),
  273. dtype=tensor.dtype,
  274. device=state.device,
  275. )
  276. gather_op(output_tensors, tensor)
  277. return output_tensors.view(-1, *tensor.size()[1:])
  278. else:
  279. # a backend of `None` is always CPU
  280. # also gloo does not support `all_gather_into_tensor`,
  281. # which will result in a larger memory overhead for the op
  282. output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
  283. torch.distributed.all_gather(output_tensors, tensor)
  284. return torch.cat(output_tensors, dim=0)
  285. return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
  286. class DistributedOperationException(Exception):
  287. """
  288. An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the
  289. tensors.
  290. """
  291. pass
  292. def verify_operation(function):
  293. """
  294. Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`.
  295. """
  296. @wraps(function)
  297. def wrapper(*args, **kwargs):
  298. if PartialState().distributed_type == DistributedType.NO or not PartialState().debug:
  299. return function(*args, **kwargs)
  300. operation = f"{function.__module__}.{function.__name__}"
  301. if "tensor" in kwargs:
  302. tensor = kwargs["tensor"]
  303. else:
  304. tensor = args[0]
  305. if PartialState().device.type != find_device(tensor).type:
  306. raise DistributedOperationException(
  307. f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. "
  308. f"Please move it to the {PartialState().device.type} before calling {operation}."
  309. )
  310. shapes = get_shape(tensor)
  311. output = gather_object([shapes])
  312. if output[0] is not None:
  313. are_same = output.count(output[0]) == len(output)
  314. if not are_same:
  315. process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)])
  316. raise DistributedOperationException(
  317. f"Cannot apply desired operation due to shape mismatches. "
  318. "All shapes across devices must be valid."
  319. f"\n\nOperation: `{operation}`\nInput shapes:\n - {process_shape_str}"
  320. )
  321. return function(*args, **kwargs)
  322. return wrapper
  323. def chained_operation(function):
  324. """
  325. Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing
  326. `DistributedOperationException`.
  327. """
  328. @wraps(function)
  329. def wrapper(*args, **kwargs):
  330. try:
  331. return function(*args, **kwargs)
  332. except DistributedOperationException as e:
  333. operation = f"{function.__module__}.{function.__name__}"
  334. raise DistributedOperationException(
  335. f"Error found while calling `{operation}`. Please see the earlier error for more details."
  336. ) from e
  337. return wrapper
  338. @verify_operation
  339. def gather(tensor):
  340. """
  341. Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
  342. Args:
  343. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  344. The data to gather.
  345. Returns:
  346. The same data structure as `tensor` with all tensors sent to the proper device.
  347. """
  348. if PartialState().distributed_type == DistributedType.XLA:
  349. return _tpu_gather(tensor)
  350. elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
  351. return _gpu_gather(tensor)
  352. else:
  353. return tensor
  354. def _gpu_gather_object(object: Any):
  355. output_objects = [None for _ in range(PartialState().num_processes)]
  356. torch.distributed.all_gather_object(output_objects, object)
  357. # all_gather_object returns a list of lists, so we need to flatten it
  358. return [x for y in output_objects for x in y]
  359. def gather_object(object: Any):
  360. """
  361. Recursively gather object in a nested list/tuple/dictionary of objects from all devices.
  362. Args:
  363. object (nested list/tuple/dictionary of picklable object):
  364. The data to gather.
  365. Returns:
  366. The same data structure as `object` with all the objects sent to every device.
  367. """
  368. if PartialState().distributed_type == DistributedType.XLA:
  369. raise NotImplementedError("gather objects in TPU is not supported")
  370. elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
  371. return _gpu_gather_object(object)
  372. else:
  373. return object
  374. def _gpu_broadcast(data, src=0):
  375. def _gpu_broadcast_one(tensor, src=0):
  376. torch.distributed.broadcast(tensor, src=src)
  377. return tensor
  378. return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)
  379. def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
  380. if isinstance(tensor, (list, tuple)):
  381. return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
  382. elif isinstance(tensor, Mapping):
  383. return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
  384. return xm.mesh_reduce(name, tensor, lambda x: x[src])
  385. TENSOR_TYPE_TO_INT = {
  386. torch.float: 1,
  387. torch.double: 2,
  388. torch.half: 3,
  389. torch.bfloat16: 4,
  390. torch.uint8: 5,
  391. torch.int8: 6,
  392. torch.int16: 7,
  393. torch.int32: 8,
  394. torch.int64: 9,
  395. torch.bool: 10,
  396. }
  397. TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()}
  398. def gather_tensor_shape(tensor):
  399. """
  400. Grabs the shape of `tensor` only available on one process and returns a tensor of its shape
  401. """
  402. # Allocate 80 bytes to store the shape
  403. max_tensor_dimension = 2**20
  404. state = PartialState()
  405. base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)
  406. # Since PyTorch can't just send a tensor to another GPU without
  407. # knowing its size, we store the size of the tensor with data
  408. # in an allocation
  409. if tensor is not None:
  410. shape = tensor.shape
  411. tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]
  412. base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)
  413. # Perform a reduction to copy the size data onto all GPUs
  414. base_tensor = reduce(base_tensor, reduction="sum")
  415. base_tensor = base_tensor[base_tensor.nonzero()]
  416. # The last non-zero data contains the coded dtype the source tensor is
  417. dtype = int(base_tensor[-1:][0])
  418. base_tensor = base_tensor[:-1]
  419. return base_tensor, dtype
  420. def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
  421. """
  422. Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
  423. each worker doesn't need to know its shape when used (and tensor can be `None`)
  424. Args:
  425. tensor (`torch.tensor`):
  426. The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest
  427. should be `None`.
  428. """
  429. state = PartialState()
  430. shape, dtype = gather_tensor_shape(tensor)
  431. if tensor is None:
  432. tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device)
  433. return reduce(tensor, reduction="sum")
  434. @verify_operation
  435. def broadcast(tensor, from_process: int = 0):
  436. """
  437. Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
  438. Args:
  439. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  440. The data to gather.
  441. from_process (`int`, *optional*, defaults to 0):
  442. The process from which to send the data
  443. Returns:
  444. The same data structure as `tensor` with all tensors broadcasted to the proper device.
  445. """
  446. if PartialState().distributed_type == DistributedType.XLA:
  447. return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
  448. elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
  449. return _gpu_broadcast(tensor, src=from_process)
  450. else:
  451. return tensor
  452. def broadcast_object_list(object_list, from_process: int = 0):
  453. """
  454. Broadcast a list of picklable objects from one process to the others.
  455. Args:
  456. object_list (list of picklable objects):
  457. The list of objects to broadcast. This list will be modified inplace.
  458. from_process (`int`, *optional*, defaults to 0):
  459. The process from which to send the data.
  460. Returns:
  461. The same list containing the objects from process 0.
  462. """
  463. if PartialState().distributed_type == DistributedType.XLA:
  464. for i, obj in enumerate(object_list):
  465. object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
  466. elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
  467. torch.distributed.broadcast_object_list(object_list, src=from_process)
  468. return object_list
  469. def slice_tensors(data, tensor_slice, process_index=None, num_processes=None):
  470. """
  471. Recursively takes a slice in a nested list/tuple/dictionary of tensors.
  472. Args:
  473. data (nested list/tuple/dictionary of `torch.Tensor`):
  474. The data to slice.
  475. tensor_slice (`slice`):
  476. The slice to take.
  477. Returns:
  478. The same data structure as `data` with all the tensors slices.
  479. """
  480. def _slice_tensor(tensor, tensor_slice):
  481. return tensor[tensor_slice]
  482. return recursively_apply(_slice_tensor, data, tensor_slice)
  483. def concatenate(data, dim=0):
  484. """
  485. Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
  486. Args:
  487. data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
  488. The data to concatenate.
  489. dim (`int`, *optional*, defaults to 0):
  490. The dimension on which to concatenate.
  491. Returns:
  492. The same data structure as `data` with all the tensors concatenated.
  493. """
  494. if isinstance(data[0], (tuple, list)):
  495. return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  496. elif isinstance(data[0], Mapping):
  497. return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  498. elif not isinstance(data[0], torch.Tensor):
  499. raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
  500. return torch.cat(data, dim=dim)
  501. class CannotPadNestedTensorWarning(UserWarning):
  502. pass
  503. @chained_operation
  504. def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
  505. """
  506. Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they
  507. can safely be gathered.
  508. Args:
  509. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  510. The data to gather.
  511. dim (`int`, *optional*, defaults to 0):
  512. The dimension on which to pad.
  513. pad_index (`int`, *optional*, defaults to 0):
  514. The value with which to pad.
  515. pad_first (`bool`, *optional*, defaults to `False`):
  516. Whether to pad at the beginning or the end.
  517. """
  518. def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
  519. if getattr(tensor, "is_nested", False):
  520. warnings.warn(
  521. "Cannot pad nested tensors without more information. Leaving unprocessed.",
  522. CannotPadNestedTensorWarning,
  523. )
  524. return tensor
  525. if dim >= len(tensor.shape) or dim < -len(tensor.shape):
  526. return tensor
  527. # Convert negative dimensions to non-negative
  528. if dim < 0:
  529. dim += len(tensor.shape)
  530. # Gather all sizes
  531. size = torch.tensor(tensor.shape, device=tensor.device)[None]
  532. sizes = gather(size).cpu()
  533. # Then pad to the maximum size
  534. max_size = max(s[dim] for s in sizes)
  535. if max_size == tensor.shape[dim]:
  536. return tensor
  537. old_size = tensor.shape
  538. new_size = list(old_size)
  539. new_size[dim] = max_size
  540. new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
  541. if pad_first:
  542. indices = tuple(
  543. slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
  544. )
  545. else:
  546. indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
  547. new_tensor[indices] = tensor
  548. return new_tensor
  549. return recursively_apply(
  550. _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
  551. )
  552. def pad_input_tensors(tensor, batch_size, num_processes, dim=0):
  553. """
  554. Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions.
  555. New tensors are just the last input repeated.
  556. E.g.:
  557. Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4])
  558. """
  559. def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
  560. remainder = batch_size // num_processes
  561. last_inputs = batch_size - (remainder * num_processes)
  562. if batch_size // num_processes == 0:
  563. to_pad = num_processes - batch_size
  564. else:
  565. to_pad = num_processes - (batch_size // num_processes)
  566. # In the rare case that `to_pad` is negative,
  567. # we need to pad the last inputs - the found `to_pad`
  568. if last_inputs > to_pad & to_pad < 1:
  569. to_pad = last_inputs - to_pad
  570. old_size = tensor.shape
  571. new_size = list(old_size)
  572. new_size[0] = batch_size + to_pad
  573. new_tensor = tensor.new_zeros(tuple(new_size))
  574. indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
  575. new_tensor[indices] = tensor
  576. return new_tensor
  577. return recursively_apply(
  578. _pad_input_tensors,
  579. tensor,
  580. error_on_other_type=True,
  581. batch_size=batch_size,
  582. num_processes=num_processes,
  583. dim=dim,
  584. )
  585. @verify_operation
  586. def reduce(tensor, reduction="mean", scale=1.0):
  587. """
  588. Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the
  589. mean of a given operation.
  590. Args:
  591. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  592. The data to reduce.
  593. reduction (`str`, *optional*, defaults to `"mean"`):
  594. A reduction method. Can be of "mean", "sum", or "none"
  595. scale (`float`, *optional*):
  596. A default scaling value to be applied after the reduce, only valid on XLA.
  597. Returns:
  598. The same data structure as `data` with all the tensors reduced.
  599. """
  600. def _reduce_across_processes(tensor, reduction="mean", scale=1.0):
  601. state = PartialState()
  602. cloned_tensor = tensor.clone()
  603. if state.distributed_type == DistributedType.NO:
  604. return cloned_tensor
  605. if state.distributed_type == DistributedType.XLA:
  606. # Some processes may have different HLO graphs than other
  607. # processes, for example in the breakpoint API
  608. # accelerator.set_trigger(). Use mark_step to make HLOs
  609. # the same on all processes.
  610. xm.mark_step()
  611. xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale)
  612. xm.mark_step()
  613. elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES:
  614. torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
  615. if reduction == "mean":
  616. cloned_tensor /= state.num_processes
  617. return cloned_tensor
  618. return recursively_apply(
  619. _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale
  620. )
  621. def convert_to_fp32(tensor):
  622. """
  623. Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
  624. Args:
  625. tensor (nested list/tuple/dictionary of `torch.Tensor`):
  626. The data to convert from FP16/BF16 to FP32.
  627. Returns:
  628. The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
  629. """
  630. def _convert_to_fp32(tensor):
  631. return tensor.float()
  632. def _is_fp16_bf16_tensor(tensor):
  633. return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
  634. torch.float16,
  635. torch.bfloat16,
  636. )
  637. return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
  638. class ConvertOutputsToFp32:
  639. """
  640. Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
  641. precision will be convert back to FP32.
  642. Args:
  643. model_forward (`Callable`):
  644. The function which outputs we want to treat.
  645. Returns:
  646. The same function as `model_forward` but with converted outputs.
  647. """
  648. def __init__(self, model_forward):
  649. self.model_forward = model_forward
  650. update_wrapper(self, model_forward)
  651. def __call__(self, *args, **kwargs):
  652. return convert_to_fp32(self.model_forward(*args, **kwargs))
  653. def __getstate__(self):
  654. raise pickle.PicklingError(
  655. "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it."
  656. )
  657. def convert_outputs_to_fp32(model_forward):
  658. model_forward = ConvertOutputsToFp32(model_forward)
  659. def forward(*args, **kwargs):
  660. return model_forward(*args, **kwargs)
  661. # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
  662. forward.__wrapped__ = model_forward
  663. return forward
  664. def find_device(data):
  665. """
  666. Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).
  667. Args:
  668. (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
  669. """
  670. if isinstance(data, Mapping):
  671. for obj in data.values():
  672. device = find_device(obj)
  673. if device is not None:
  674. return device
  675. elif isinstance(data, (tuple, list)):
  676. for obj in data:
  677. device = find_device(obj)
  678. if device is not None:
  679. return device
  680. elif isinstance(data, torch.Tensor):
  681. return data.device
  682. @contextmanager
  683. def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):
  684. """
  685. Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context
  686. manager.
  687. """
  688. # We need to use the `AcceleratorState` here since it has access to the deepspeed plugin
  689. if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (
  690. AcceleratorState().deepspeed_plugin is not None
  691. and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()
  692. ):
  693. gather_param_context = nullcontext()
  694. else:
  695. import deepspeed
  696. gather_param_context = deepspeed.zero.GatheredParameters(
  697. params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled
  698. )
  699. with gather_param_context:
  700. yield