microbatch.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. import operator
  5. from typing import Any, Optional
  6. import torch
  7. from torch.fx.node import map_aggregate
  8. from torch.utils._pytree import tree_flatten, tree_unflatten
  9. __all__ = [
  10. "TensorChunkSpec",
  11. "split_args_kwargs_into_chunks",
  12. "merge_chunks",
  13. ]
  14. logger = logging.getLogger(__name__)
  15. """
  16. _debug_mask_minibatches specifies to send masked versions of the mini-batch
  17. through instead of micro-batch slices--this can be used for more stable
  18. numerical testing (see [A Note About Correctness Testing])
  19. """
  20. _debug_mask_minibatches = False
  21. class _CustomReducer:
  22. """
  23. Custom reducer class that can be used to specify a custom operation that
  24. reduces losses of multiple microbatches into one value.
  25. Example:
  26. >>> # xdoctest: +SKIP
  27. >>> sum_reducer = _CustomReducer(
  28. >>> torch.tensor(0.0),
  29. >>> lambda a, b: a + b
  30. >>> )
  31. """
  32. def __init__(self, init_value, reduce_fn):
  33. self.init_value = init_value
  34. self.reduce_fn = reduce_fn
  35. class _LossReducer(_CustomReducer):
  36. pass
  37. sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)
  38. # Default chunking dimension is 0. This is used for the case where the user did
  39. # not specify a chunking dimension.
  40. DEFAULT_CHUNK_DIM = 0
  41. class TensorChunkSpec:
  42. """
  43. Class used to specify chunking of inputs
  44. """
  45. def __init__(self, split_dim):
  46. self.split_dim = split_dim
  47. split_dim: int
  48. def __repr__(self):
  49. return (
  50. f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
  51. )
  52. def __str__(self):
  53. return f"TensorChunkSpec({self.split_dim})"
  54. @staticmethod
  55. def from_tuple(
  56. chunk_dims: tuple[int, ...],
  57. ):
  58. """
  59. A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
  60. dimensions (int's).
  61. Example:
  62. >>> # xdoctest: +SKIP
  63. >>> # There are three positional arguments to the model, and
  64. >>> # we are chunking them along dimension 0, 0 and 1, respectively
  65. >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
  66. """
  67. args_chunk_spec = map_aggregate(
  68. chunk_dims,
  69. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  70. )
  71. return args_chunk_spec
  72. @staticmethod
  73. def from_dict(
  74. chunk_dims: dict[str, int],
  75. ):
  76. """
  77. A helper for creating a dictionary of `TensorChunkSpec` from a
  78. dictionary of chunk dimensions (int's).
  79. Example:
  80. >>> # xdoctest: +SKIP
  81. >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
  82. >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
  83. """
  84. kwargs_chunk_spec = map_aggregate(
  85. chunk_dims,
  86. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  87. )
  88. return kwargs_chunk_spec
  89. # Class used to specify replication of inputs
  90. class _Replicate:
  91. pass
  92. def _shard_dict_of_args(
  93. args_dict,
  94. args_chunk_spec,
  95. num_chunks,
  96. ):
  97. """
  98. Given a dictionary of args, and a dictionary of chunking specs, shard the
  99. args according to the chunking specs.
  100. Args:
  101. args_dict: Dictionary of args
  102. args_chunk_spec: Dictionary of chunking specs
  103. num_chunks: Number of chunks to shard the args into
  104. Returns:
  105. args_split: List of sharded args
  106. """
  107. # Stage 1+2: flatten and shard/replicate
  108. # args_sharded_replicated : [num args, num flat values, num chunks]
  109. args_sharded_replicated = {}
  110. arg_specs = []
  111. real_num_chunks = num_chunks
  112. first_tensor = True
  113. assert len(args_dict) == len(args_chunk_spec), (
  114. f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
  115. )
  116. for arg_key, arg in args_dict.items():
  117. flat, spec = tree_flatten(arg)
  118. arg_specs.append(spec)
  119. chunk_spec = args_chunk_spec[arg_key]
  120. assert chunk_spec is not None # Should have been set by caller
  121. chunk_spec_flat, _ = tree_flatten(chunk_spec)
  122. if len(flat) != len(chunk_spec_flat):
  123. raise ValueError(
  124. f"Argument value {arg} did not have the same number of "
  125. f"values as as chunk spec {chunk_spec}"
  126. )
  127. sharded_arg_flat = []
  128. for v, chunk_v in zip(flat, chunk_spec_flat):
  129. if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
  130. sharded_arg_flat.append([v] * real_num_chunks)
  131. elif isinstance(chunk_v, TensorChunkSpec):
  132. # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
  133. # If it's a collection type, split it as you would expect. Otherwise,
  134. # Throw an error
  135. assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
  136. v_split_dim_size = v.size(chunk_v.split_dim)
  137. if v_split_dim_size < real_num_chunks:
  138. if first_tensor:
  139. # We can only adjust number of chunks when we hit this
  140. # issue at the first tensor encountered
  141. logger.warning(
  142. f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004
  143. f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
  144. )
  145. real_num_chunks = v_split_dim_size
  146. else:
  147. raise RuntimeError(
  148. f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
  149. f"smaller than the number of chunks {num_chunks}. "
  150. "PiPPy cannot reduce the number of chunks because "
  151. "other arguments have bigger chunk-dimension sizes. "
  152. "Please adjust your num_chunks setting."
  153. )
  154. chunk_tensors = torch.tensor_split(
  155. v, real_num_chunks, chunk_v.split_dim
  156. )
  157. if _debug_mask_minibatches:
  158. expanded_chunks = []
  159. split_dim_idx = 0
  160. for chunk_tensor in chunk_tensors:
  161. new_val = torch.zeros_like(v)
  162. upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
  163. slice_indices = [slice(None, None, None)] * new_val.ndim
  164. slice_indices[chunk_v.split_dim] = slice(
  165. split_dim_idx, upper_idx
  166. )
  167. new_val[slice_indices] = chunk_tensor
  168. expanded_chunks.append(new_val)
  169. split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
  170. sharded_arg_flat.append(expanded_chunks)
  171. else:
  172. sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type]
  173. first_tensor = False
  174. else:
  175. raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
  176. args_sharded_replicated[arg_key] = sharded_arg_flat
  177. # chunks_flat : [num chunks, num args, num flat values]
  178. chunks_flat = []
  179. for chunk_idx in range(real_num_chunks):
  180. chunk_args = {}
  181. for key, arg in args_sharded_replicated.items():
  182. arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
  183. chunk_args[key] = arg_single_chunk
  184. chunks_flat.append(chunk_args)
  185. # args_split : [num chunks, num args]
  186. args_split = []
  187. for chunk in chunks_flat:
  188. per_chunk_args = {}
  189. assert len(arg_specs) == len(chunk)
  190. for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
  191. per_chunk_args[key] = tree_unflatten(arg, arg_spec)
  192. args_split.append(per_chunk_args)
  193. return args_split
  194. def split_args_kwargs_into_chunks(
  195. args: tuple[Any, ...],
  196. kwargs: Optional[dict[str, Any]],
  197. chunks: int,
  198. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  199. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  200. ) -> tuple[list[tuple], list[dict]]:
  201. """
  202. Given a sequence of args and kwargs, split them into a number of chunks
  203. according to their respective chunking specs.
  204. Args:
  205. args: Tuple of args
  206. kwargs: Dict of kwargs
  207. chunks: Number of chunks to split the args and kwargs into
  208. args_chunk_spec: chunking specs for args, in same shape as args
  209. kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
  210. Returns:
  211. args_split: List of sharded args
  212. kwargs_split: List of sharded kwargs
  213. """
  214. # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
  215. # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
  216. # and `kwargs_chunk_spec` specifications. The steps are as follows:
  217. #
  218. # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
  219. # To use a running example: suppose our inputs look like
  220. #
  221. # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
  222. # (kwargs not shown but it's a similar process)
  223. #
  224. # Then for this step we would end up with
  225. #
  226. # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
  227. #
  228. # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
  229. #
  230. # args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
  231. #
  232. # 3. Rotate the nesting order such that chunks are the outer dimension
  233. #
  234. # args_chunks = [
  235. # ([A, B, C_1], D),
  236. # ([A, B, C_2], D),
  237. # ]
  238. #
  239. # 4. Unflatten each chunk according to the spec
  240. #
  241. # args_chunks = [
  242. # ([A, [B, C_1]], D),
  243. # ([A, [B, C_2]], D),
  244. # ]
  245. # TODO: _debug_mask_minibatches
  246. # Handle the case where kwargs is None
  247. if kwargs is None:
  248. kwargs = {}
  249. # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
  250. # their format and use default chunking along dim 0
  251. if args_chunk_spec is None:
  252. args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
  253. if kwargs_chunk_spec is None:
  254. kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
  255. args_split_dict = _shard_dict_of_args(
  256. dict(enumerate(args)),
  257. dict(enumerate(args_chunk_spec)),
  258. chunks,
  259. )
  260. real_num_chunks = len(args_split_dict)
  261. kwargs_split = _shard_dict_of_args(
  262. kwargs,
  263. kwargs_chunk_spec,
  264. real_num_chunks,
  265. )
  266. if len(kwargs_split) < real_num_chunks:
  267. # In case kwargs are sharded into less chunks
  268. # e.g. when `args` has no tensor, just values
  269. real_num_chunks = len(kwargs_split)
  270. # Re-shard args
  271. args_split_dict = _shard_dict_of_args(
  272. dict(enumerate(args)),
  273. dict(enumerate(args_chunk_spec)),
  274. real_num_chunks,
  275. )
  276. if len(args_split_dict) != len(kwargs_split):
  277. raise RuntimeError(
  278. "args and kwargs are split into different number of chunks: "
  279. f"{len(args_split_dict)}, {len(kwargs_split)}"
  280. )
  281. args_split = [
  282. tuple(chunk_args[i] for i in range(len(chunk_args)))
  283. for chunk_args in args_split_dict
  284. ]
  285. return args_split, kwargs_split
  286. def merge_chunks(
  287. chunks: list[Any],
  288. chunk_spec,
  289. ):
  290. """
  291. Given a list of chunks, merge them into a single value according to
  292. the chunk spec.
  293. Args:
  294. chunks: list of chunks
  295. chunk_spec: Chunking spec for the chunks
  296. Returns:
  297. value: Merged value
  298. """
  299. # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
  300. # steps are similar to the steps in that function but in reverse. Given the
  301. # input values:
  302. #
  303. # chunks = [
  304. # ([A, [B, C_1]], D),
  305. # ([A, [B, C_2]], D),
  306. # ]
  307. # args_spec = ([None, [None, TensorChunkSpec]], None)
  308. #
  309. # 1. Flatten the chunks according to the chunk_spec
  310. #
  311. # chunks_flat = [
  312. # ([A, B, C_1], D),
  313. # ([A, B, C_2], D),
  314. # ]
  315. #
  316. # 2. Rotate the nesting order such that chunks are the inner dimension
  317. #
  318. # value_inner = ([A, B, [C_1, C_2]], D)
  319. #
  320. # 3. Concatenate sharded arguments
  321. #
  322. # value_combined = ([A, B, C], D)
  323. #
  324. # 4. Unflatten the combined args given the spec
  325. #
  326. # value = ([A, [B, C]], D)
  327. # Preliminary: flatten the chunk spec
  328. if chunk_spec is not None:
  329. spec_flattened, flatten_spec = tree_flatten(chunk_spec)
  330. else:
  331. # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
  332. # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
  333. chunk0_flat, flatten_spec = tree_flatten(chunks[0])
  334. spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
  335. # Stage 1: flatten chunks
  336. # chunks_flattened : [num chunks, num args]
  337. chunks_flattened = []
  338. for chunk in chunks:
  339. chunk_flattened, _ = tree_flatten(chunk)
  340. if len(chunk_flattened) != len(spec_flattened):
  341. raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
  342. chunks_flattened.append(chunk_flattened)
  343. # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
  344. # concatenate sharded operands
  345. # args_flattened : [num args]
  346. args_flattened = []
  347. for arg_idx, arg in enumerate(spec_flattened):
  348. if isinstance(arg, TensorChunkSpec):
  349. partial_values = [
  350. chunks_flattened[chunk_idx][arg_idx]
  351. for chunk_idx in range(len(chunks_flattened))
  352. ]
  353. if _debug_mask_minibatches:
  354. # Infer size of individual chunks by running `tensor_split` again
  355. overall_shape = partial_values[0].shape
  356. for val in partial_values[1:]:
  357. assert val.shape == overall_shape
  358. meta_chunks = torch.tensor_split(
  359. torch.empty(*overall_shape, device="meta"),
  360. sections=len(partial_values),
  361. dim=arg.split_dim,
  362. )
  363. values_to_cat = []
  364. chunk_start_idx = 0
  365. assert len(partial_values) == len(meta_chunks)
  366. for partial_value, meta_chunk in zip(partial_values, meta_chunks):
  367. chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
  368. slice_indices = [slice(None, None, None)] * partial_value.ndim
  369. slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
  370. sliced = partial_value[slice_indices]
  371. values_to_cat.append(sliced)
  372. chunk_start_idx = chunk_end_idx
  373. else:
  374. values_to_cat = partial_values
  375. args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
  376. elif isinstance(arg, _CustomReducer):
  377. reduced_val = arg.init_value
  378. for chunk_idx in range(len(chunks_flattened)):
  379. reduced_val = arg.reduce_fn(
  380. reduced_val, chunks_flattened[chunk_idx][arg_idx]
  381. )
  382. args_flattened.append(reduced_val)
  383. else:
  384. value = chunks_flattened[0][arg_idx]
  385. for chunk_idx in range(1, len(chunks_flattened)):
  386. assert chunks_flattened[chunk_idx][arg_idx] == value
  387. args_flattened.append(value)
  388. # Stage 4: Unflatten combined args
  389. return tree_unflatten(args_flattened, flatten_spec)