microbatch.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. import operator
  5. from collections.abc import Sequence
  6. from typing import Any
  7. import torch
  8. from torch.fx.node import map_aggregate
  9. from torch.nn.attention.flex_attention import BlockMask
  10. from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
# Public names exported by this module.
__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

"""
_debug_mask_minibatches specifies to send masked versions of the mini-batch
through instead of micro-batch slices--this can be used for more stable
numerical testing (see [A Note About Correctness Testing])
"""
# When True, `_split_tensor` returns full-size zero tensors that each carry a
# single chunk's slice, and `merge_chunks` slices those regions back out
# before concatenating.
_debug_mask_minibatches = False
  23. class _CustomReducer:
  24. """
  25. Custom reducer class that can be used to specify a custom operation that
  26. reduces losses of multiple microbatches into one value.
  27. Example:
  28. >>> # xdoctest: +SKIP
  29. >>> sum_reducer = _CustomReducer(
  30. >>> torch.tensor(0.0),
  31. >>> lambda a, b: a + b
  32. >>> )
  33. """
  34. def __init__(self, init_value, reduce_fn):
  35. self.init_value = init_value
  36. self.reduce_fn = reduce_fn
  37. class _LossReducer(_CustomReducer):
  38. pass
  39. sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)
# Default chunking dimension is 0. It is used when the user did not specify a
# chunking dimension: `split_args_kwargs_into_chunks` assigns it to every
# Tensor/BlockMask input that has no spec, and `merge_chunks` concatenates
# along it when called without a chunk spec.
DEFAULT_CHUNK_DIM = 0
  43. class TensorChunkSpec:
  44. """
  45. Class used to specify chunking of inputs
  46. """
  47. def __init__(self, split_dim):
  48. self.split_dim = split_dim
  49. split_dim: int
  50. def __repr__(self):
  51. return (
  52. f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
  53. )
  54. def __str__(self):
  55. return f"TensorChunkSpec({self.split_dim})"
  56. @staticmethod
  57. def from_tuple(
  58. chunk_dims: tuple[int, ...],
  59. ):
  60. """
  61. A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
  62. dimensions (int's).
  63. Example:
  64. >>> # xdoctest: +SKIP
  65. >>> # There are three positional arguments to the model, and
  66. >>> # we are chunking them along dimension 0, 0 and 1, respectively
  67. >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
  68. """
  69. args_chunk_spec = map_aggregate(
  70. chunk_dims,
  71. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  72. )
  73. return args_chunk_spec
  74. @staticmethod
  75. def from_dict(
  76. chunk_dims: dict[str, int],
  77. ):
  78. """
  79. A helper for creating a dictionary of `TensorChunkSpec` from a
  80. dictionary of chunk dimensions (int's).
  81. Example:
  82. >>> # xdoctest: +SKIP
  83. >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
  84. >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
  85. """
  86. kwargs_chunk_spec = map_aggregate(
  87. chunk_dims,
  88. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  89. )
  90. return kwargs_chunk_spec
  91. # Class used to specify replication of inputs
  92. class _Replicate:
  93. pass
  94. def _split_block_mask(
  95. block_mask: BlockMask,
  96. num_chunks: int,
  97. ) -> list[BlockMask]:
  98. """Given a block mask, split the block mask along the batch dimension (dim0).
  99. Args:
  100. block_mask: Block mask to split
  101. num_chunks: Number of chunks to split the block mask into
  102. Returns:
  103. chunk_block_masks: List of chunked block masks
  104. """
  105. # BlockMask will broadcast if B is 1.
  106. if block_mask.kv_num_blocks.size(0) == 1:
  107. return [block_mask] * num_chunks
  108. assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
  109. "Block mask has fewer batch size than the number of chunks. "
  110. )
  111. batch_dim = 0
  112. kv_num_blocks_chunks = torch.tensor_split(
  113. block_mask.kv_num_blocks, num_chunks, batch_dim
  114. )
  115. kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
  116. full_kv_num_blocks_chunks = (
  117. torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
  118. if block_mask.full_kv_num_blocks is not None
  119. else [None] * num_chunks
  120. )
  121. full_kv_indices_chunks = (
  122. torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
  123. if block_mask.full_kv_indices is not None
  124. else [None] * num_chunks
  125. )
  126. chunk_block_masks = []
  127. batch_offset = 0
  128. for chunk_idx in range(num_chunks):
  129. def create_mask_mod(idx):
  130. def batch_offset_mask_mod(b, h, q_idx, kv_idx):
  131. b_offset = torch.full_like(b, idx)
  132. return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
  133. return batch_offset_mask_mod
  134. chunk_block_masks.append(
  135. BlockMask.from_kv_blocks(
  136. kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
  137. kv_indices=kv_indices_chunks[chunk_idx],
  138. full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
  139. full_kv_indices=full_kv_indices_chunks[chunk_idx],
  140. BLOCK_SIZE=block_mask.BLOCK_SIZE,
  141. mask_mod=create_mask_mod(batch_offset),
  142. seq_lengths=block_mask.seq_lengths,
  143. )
  144. )
  145. batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
  146. return chunk_block_masks
  147. def _split_tensor(
  148. tensor: torch.Tensor,
  149. spec: TensorChunkSpec,
  150. num_chunks: int,
  151. ) -> Sequence[torch.Tensor]:
  152. """Given a tensor, and a chunking spec, split the tensor.
  153. Args:
  154. tensor: Tensor to split
  155. spec: Chunking spec
  156. num_chunks: Number of chunks to split the tensor into
  157. Returns:
  158. chunk_tensors: List of chunked tensors
  159. """
  160. assert tensor.size(spec.split_dim) >= num_chunks, (
  161. f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
  162. )
  163. chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)
  164. if not _debug_mask_minibatches:
  165. return chunk_tensors
  166. expanded_chunks = []
  167. split_dim_idx = 0
  168. for chunk_tensor in chunk_tensors:
  169. new_val = torch.zeros_like(tensor)
  170. upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)
  171. slice_indices = [slice(None, None, None)] * new_val.ndim
  172. slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
  173. new_val[slice_indices] = chunk_tensor
  174. expanded_chunks.append(new_val)
  175. split_dim_idx += chunk_tensor.size(spec.split_dim)
  176. return expanded_chunks
  177. def _shard_dict_of_args(
  178. args_dict,
  179. args_chunk_spec,
  180. num_chunks,
  181. ):
  182. """
  183. Given a dictionary of args, and a dictionary of chunking specs, shard the
  184. args according to the chunking specs.
  185. Args:
  186. args_dict: Dictionary of args
  187. args_chunk_spec: Dictionary of chunking specs
  188. num_chunks: Number of chunks to shard the args into
  189. Returns:
  190. args_split: List of sharded args
  191. """
  192. if not args_dict:
  193. return [{} for _ in range(num_chunks)]
  194. assert len(args_dict) == len(args_chunk_spec), (
  195. f"args_dict.keys() = {list(args_dict.keys())} "
  196. f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
  197. )
  198. assert args_chunk_spec is not None # Should have been set by caller
  199. values, tree_spec = tree_flatten(
  200. args_dict, is_leaf=lambda x: isinstance(x, BlockMask)
  201. )
  202. chunk_specs, _ = tree_flatten(
  203. args_chunk_spec, is_leaf=lambda x: isinstance(x, BlockMask)
  204. )
  205. # First check and find the actual number of chunks
  206. split_sizes = []
  207. for v, spec in zip(values, chunk_specs, strict=True):
  208. # The original logic is "spec is _Replicate". This doesn't seem to be
  209. # correct. But we keep it for backward compatibility.
  210. if spec is _Replicate or isinstance(spec, _Replicate):
  211. split_sizes.append(num_chunks)
  212. elif isinstance(v, torch.Tensor):
  213. assert isinstance(spec, TensorChunkSpec)
  214. split_sizes.append(v.size(spec.split_dim))
  215. elif isinstance(v, BlockMask):
  216. assert isinstance(spec, TensorChunkSpec)
  217. assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
  218. # BlockMask will broadcast if B is 1.
  219. if v.kv_num_blocks.size(0) == 1:
  220. split_sizes.append(num_chunks)
  221. else:
  222. split_sizes.append(v.kv_num_blocks.size(0))
  223. else:
  224. raise ValueError(
  225. f"Unsupported chunk spec: {spec} and value: {v} combination."
  226. )
  227. result_num_chunks = min(*split_sizes, num_chunks)
  228. flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
  229. for v, spec in zip(values, chunk_specs, strict=True):
  230. v_splits: Sequence[Any] = []
  231. if spec is _Replicate or isinstance(spec, _Replicate):
  232. v_splits = [v] * result_num_chunks
  233. elif isinstance(v, torch.Tensor):
  234. v_splits = _split_tensor(v, spec, result_num_chunks)
  235. elif isinstance(v, BlockMask):
  236. v_splits = _split_block_mask(v, result_num_chunks)
  237. else:
  238. raise ValueError(
  239. f"Unsupported chunk spec: {spec} and value: {v} combination."
  240. )
  241. for _flat_split_result, _v_split in zip(
  242. flat_split_results, v_splits, strict=True
  243. ):
  244. _flat_split_result.append(_v_split)
  245. return [
  246. tree_unflatten(_flat_split_result, tree_spec)
  247. for _flat_split_result in flat_split_results
  248. ]
  249. def split_args_kwargs_into_chunks(
  250. args: tuple[Any, ...],
  251. kwargs: dict[str, Any] | None,
  252. chunks: int,
  253. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  254. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  255. ) -> tuple[list[tuple], list[dict]]:
  256. """
  257. Given a sequence of args and kwargs, split them into a number of chunks
  258. according to their respective chunking specs.
  259. Args:
  260. args: Tuple of args
  261. kwargs: Dict of kwargs
  262. chunks: Number of chunks to split the args and kwargs into
  263. args_chunk_spec: chunking specs for args, in same shape as args
  264. kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
  265. Returns:
  266. args_split: List of sharded args
  267. kwargs_split: List of sharded kwargs
  268. """
  269. # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
  270. # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
  271. # and `kwargs_chunk_spec` specifications. The steps are as follows:
  272. #
  273. # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
  274. # To use a running example: suppose our inputs look like
  275. #
  276. # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
  277. # (kwargs not shown but it's a similar process)
  278. #
  279. # Then for this step we would end up with
  280. #
  281. # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
  282. #
  283. # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
  284. #
  285. # args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
  286. #
  287. # 3. Rotate the nesting order such that chunks are the outer dimension
  288. #
  289. # args_chunks = [
  290. # ([A, B, C_1], D),
  291. # ([A, B, C_2], D),
  292. # ]
  293. #
  294. # 4. Unflatten each chunk according to the spec
  295. #
  296. # args_chunks = [
  297. # ([A, [B, C_1]], D),
  298. # ([A, [B, C_2]], D),
  299. # ]
  300. # TODO: _debug_mask_minibatches
  301. # Handle the case where kwargs is None
  302. if kwargs is None:
  303. kwargs = {}
  304. # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
  305. # their format and use default chunking along dim 0
  306. def default_spec(v):
  307. if isinstance(v, torch.Tensor | BlockMask):
  308. return TensorChunkSpec(DEFAULT_CHUNK_DIM)
  309. else:
  310. return _Replicate()
  311. if args_chunk_spec is None:
  312. args_chunk_spec = tree_map(
  313. default_spec, args, is_leaf=lambda v: isinstance(v, BlockMask)
  314. )
  315. if kwargs_chunk_spec is None:
  316. kwargs_chunk_spec = tree_map(
  317. default_spec, kwargs, is_leaf=lambda v: isinstance(v, BlockMask)
  318. )
  319. args_split_dict = _shard_dict_of_args(
  320. dict(enumerate(args)),
  321. dict(enumerate(args_chunk_spec)),
  322. chunks,
  323. )
  324. real_num_chunks = len(args_split_dict)
  325. kwargs_split = _shard_dict_of_args(
  326. kwargs,
  327. kwargs_chunk_spec,
  328. real_num_chunks,
  329. )
  330. if len(kwargs_split) < real_num_chunks:
  331. # In case kwargs are sharded into less chunks
  332. # e.g. when `args` has no tensor, just values
  333. real_num_chunks = len(kwargs_split)
  334. # Re-shard args
  335. args_split_dict = _shard_dict_of_args(
  336. dict(enumerate(args)),
  337. dict(enumerate(args_chunk_spec)),
  338. real_num_chunks,
  339. )
  340. if len(args_split_dict) != len(kwargs_split):
  341. raise RuntimeError(
  342. "args and kwargs are split into different number of chunks: "
  343. f"{len(args_split_dict)}, {len(kwargs_split)}"
  344. )
  345. args_split = [
  346. tuple(chunk_args[i] for i in range(len(chunk_args)))
  347. for chunk_args in args_split_dict
  348. ]
  349. return args_split, kwargs_split
  350. def merge_chunks(
  351. chunks: list[Any],
  352. chunk_spec,
  353. ):
  354. """
  355. Given a list of chunks, merge them into a single value according to
  356. the chunk spec.
  357. Args:
  358. chunks: list of chunks
  359. chunk_spec: Chunking spec for the chunks
  360. Returns:
  361. value: Merged value
  362. """
  363. # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
  364. # steps are similar to the steps in that function but in reverse. Given the
  365. # input values:
  366. #
  367. # chunks = [
  368. # ([A, [B, C_1]], D),
  369. # ([A, [B, C_2]], D),
  370. # ]
  371. # args_spec = ([None, [None, TensorChunkSpec]], None)
  372. #
  373. # 1. Flatten the chunks according to the chunk_spec
  374. #
  375. # chunks_flat = [
  376. # ([A, B, C_1], D),
  377. # ([A, B, C_2], D),
  378. # ]
  379. #
  380. # 2. Rotate the nesting order such that chunks are the inner dimension
  381. #
  382. # value_inner = ([A, B, [C_1, C_2]], D)
  383. #
  384. # 3. Concatenate sharded arguments
  385. #
  386. # value_combined = ([A, B, C], D)
  387. #
  388. # 4. Unflatten the combined args given the spec
  389. #
  390. # value = ([A, [B, C]], D)
  391. # Preliminary: flatten the chunk spec
  392. if chunk_spec is not None:
  393. spec_flattened, flatten_spec = tree_flatten(chunk_spec)
  394. else:
  395. # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
  396. # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
  397. chunk0_flat, flatten_spec = tree_flatten(chunks[0])
  398. spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
  399. # Stage 1: flatten chunks
  400. # chunks_flattened : [num chunks, num args]
  401. chunks_flattened = []
  402. for chunk in chunks:
  403. chunk_flattened, _ = tree_flatten(chunk)
  404. if len(chunk_flattened) != len(spec_flattened):
  405. raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
  406. chunks_flattened.append(chunk_flattened)
  407. # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
  408. # concatenate sharded operands
  409. # args_flattened : [num args]
  410. args_flattened = []
  411. for arg_idx, arg in enumerate(spec_flattened):
  412. if isinstance(arg, TensorChunkSpec):
  413. partial_values = [
  414. chunks_flattened[chunk_idx][arg_idx]
  415. for chunk_idx in range(len(chunks_flattened))
  416. ]
  417. if _debug_mask_minibatches:
  418. # Infer size of individual chunks by running `tensor_split` again
  419. overall_shape = partial_values[0].shape
  420. for val in partial_values[1:]:
  421. assert val.shape == overall_shape
  422. meta_chunks = torch.tensor_split(
  423. torch.empty(*overall_shape, device="meta"),
  424. sections=len(partial_values),
  425. dim=arg.split_dim,
  426. )
  427. values_to_cat = []
  428. chunk_start_idx = 0
  429. assert len(partial_values) == len(meta_chunks)
  430. for partial_value, meta_chunk in zip(
  431. partial_values, meta_chunks, strict=True
  432. ):
  433. chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
  434. slice_indices = [slice(None, None, None)] * partial_value.ndim
  435. slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
  436. sliced = partial_value[slice_indices]
  437. values_to_cat.append(sliced)
  438. chunk_start_idx = chunk_end_idx
  439. else:
  440. values_to_cat = partial_values
  441. args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
  442. elif isinstance(arg, _CustomReducer):
  443. reduced_val = arg.init_value
  444. for chunk_idx in range(len(chunks_flattened)):
  445. reduced_val = arg.reduce_fn(
  446. reduced_val, chunks_flattened[chunk_idx][arg_idx]
  447. )
  448. args_flattened.append(reduced_val)
  449. else:
  450. value = chunks_flattened[0][arg_idx]
  451. for chunk_idx in range(1, len(chunks_flattened)):
  452. assert chunks_flattened[chunk_idx][arg_idx] == value
  453. args_flattened.append(value)
  454. # Stage 4: Unflatten combined args
  455. return tree_unflatten(args_flattened, flatten_spec)