flop_counter.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  4. from .module_tracker import ModuleTracker
  5. from typing import Any, TypeVar
  6. from collections.abc import Callable
  7. from collections.abc import Iterator
  8. from typing_extensions import ParamSpec
  9. from collections import defaultdict
  10. from torch.utils._python_dispatch import TorchDispatchMode
  11. from math import prod
  12. from functools import wraps
  13. import warnings
  14. __all__ = ["FlopCounterMode", "register_flop_formula"]
  15. _T = TypeVar("_T")
  16. _P = ParamSpec("_P")
  17. aten = torch.ops.aten
  18. def get_shape(i):
  19. if isinstance(i, torch.Tensor):
  20. return i.shape
  21. return i
  22. flop_registry: dict[Any, Any] = {}
  23. def shape_wrapper(f):
  24. @wraps(f)
  25. def nf(*args, out_val=None, **kwargs):
  26. args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
  27. return f(*args, out_shape=out_shape, **kwargs)
  28. return nf
  29. def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  30. def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
  31. if not get_raw:
  32. flop_formula = shape_wrapper(flop_formula)
  33. def register(target) -> None:
  34. if not isinstance(target, torch._ops.OpOverloadPacket):
  35. raise ValueError(
  36. f"register_flop_formula(targets): expected each target to be "
  37. f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
  38. f"{target} which is of type {type(target)}")
  39. if target in flop_registry:
  40. raise RuntimeError(f"duplicate registrations for {target}")
  41. flop_registry[target] = flop_formula
  42. # To handle allowing multiple aten_ops at once
  43. torch.utils._pytree.tree_map_(register, targets)
  44. return flop_formula
  45. return register_fun
  46. @register_flop_formula(aten.mm)
  47. def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
  48. """Count flops for matmul."""
  49. # Inputs should be a list of length 2.
  50. # Inputs contains the shapes of two matrices.
  51. m, k = a_shape
  52. k2, n = b_shape
  53. if k != k2:
  54. raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {k} and {k2}")
  55. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  56. return m * n * 2 * k
  57. @register_flop_formula(aten.addmm)
  58. def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  59. """Count flops for addmm."""
  60. return mm_flop(a_shape, b_shape)
  61. @register_flop_formula(aten.bmm)
  62. def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
  63. """Count flops for the bmm operation."""
  64. # Inputs should be a list of length 2.
  65. # Inputs contains the shapes of two tensor.
  66. b, m, k = a_shape
  67. b2, k2, n = b_shape
  68. if b != b2:
  69. raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {b} and {b2}")
  70. if k != k2:
  71. raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {k} and {k2}")
  72. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  73. flop = b * m * n * 2 * k
  74. return flop
  75. @register_flop_formula(aten.baddbmm)
  76. def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  77. """Count flops for the baddbmm operation."""
  78. # Inputs should be a list of length 3.
  79. # Inputs contains the shapes of three tensors.
  80. return bmm_flop(a_shape, b_shape)
  81. @register_flop_formula(aten._scaled_mm)
  82. def _scaled_mm_flop(
  83. a_shape,
  84. b_shape,
  85. scale_a_shape,
  86. scale_b_shape,
  87. bias_shape=None,
  88. scale_result_shape=None,
  89. out_dtype=None,
  90. use_fast_accum=False,
  91. out_shape=None,
  92. **kwargs,
  93. ) -> int:
  94. """Count flops for _scaled_mm."""
  95. return mm_flop(a_shape, b_shape)
  96. def conv_flop_count(
  97. x_shape: list[int],
  98. w_shape: list[int],
  99. out_shape: list[int],
  100. transposed: bool = False,
  101. ) -> int:
  102. """Count flops for convolution.
  103. Note only multiplication is
  104. counted. Computation for bias are ignored.
  105. Flops for a transposed convolution are calculated as
  106. flops = (x_shape[2:] * prod(w_shape) * batch_size).
  107. Args:
  108. x_shape (list(int)): The input shape before convolution.
  109. w_shape (list(int)): The filter shape.
  110. out_shape (list(int)): The output shape after convolution.
  111. transposed (bool): is the convolution transposed
  112. Returns:
  113. int: the number of flops
  114. """
  115. batch_size = x_shape[0]
  116. conv_shape = (x_shape if transposed else out_shape)[2:]
  117. c_out, c_in, *filter_size = w_shape
  118. """
  119. General idea here is that for a regular conv, for each point in the output
  120. spatial dimension we convolve the filter with something (hence
  121. `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
  122. 1. batch_size, 2. the cross product of input and weight channels.
  123. For the transpose, it's not each point in the *output* spatial dimension but
  124. each point in the *input* spatial dimension.
  125. """
  126. # NB(chilli): I don't think this properly accounts for padding :think:
  127. # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
  128. flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
  129. return flop
  130. @register_flop_formula([aten.convolution,
  131. aten._convolution,
  132. aten.cudnn_convolution,
  133. aten._slow_conv2d_forward,
  134. aten.convolution_overrideable])
  135. def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
  136. """Count flops for convolution."""
  137. # pyrefly: ignore [bad-argument-type]
  138. return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
  139. @register_flop_formula(aten.convolution_backward)
  140. def conv_backward_flop(
  141. grad_out_shape,
  142. x_shape,
  143. w_shape,
  144. _bias,
  145. _stride,
  146. _padding,
  147. _dilation,
  148. transposed,
  149. _output_padding,
  150. _groups,
  151. output_mask,
  152. out_shape) -> int:
  153. def t(shape):
  154. return [shape[1], shape[0]] + list(shape[2:])
  155. flop_count = 0
  156. """
  157. Let's say we have a regular 1D conv
  158. {A, B, C} [inp]
  159. {i, j} [weight]
  160. => (conv)
  161. {Ai + Bj, Bi + Cj} [out]
  162. And as a reminder, the transposed conv of the above is
  163. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  164. For the backwards of conv, we now have
  165. {D, E} [grad_out]
  166. {A, B, C} [inp]
  167. {i, j} [weight]
  168. # grad_inp as conv_transpose(grad_out, weight)
  169. Let's first compute grad_inp. To do so, we can simply look at all the
  170. multiplications that each element of inp is involved in. For example, A is
  171. only involved in the first element of the output (and thus only depends upon
  172. D in grad_out), and C is only involved in the last element of the output
  173. (and thus only depends upon E in grad_out)
  174. {Di, Dj + Ei, Ej} [grad_inp]
  175. Note that this corresponds to the below conv_transpose. This gives us the
  176. output_mask[0] branch, which is grad_inp.
  177. {D, E} [inp (grad_out)]
  178. {i, j} [weight]
  179. => (conv_transpose)
  180. {Di, Dj + Ei, Ej} [out (grad_inp)]
  181. I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
  182. weight) as an exercise for the reader.
  183. # grad_weight as conv(inp, grad_out)
  184. To compute grad_weight, we again look at the terms in the output, which as
  185. a reminder is:
  186. => {Ai + Bj, Bi + Cj} [out]
  187. => {D, E} [grad_out]
  188. If we manually compute the gradient for the weights, we see it's
  189. {AD + BE, BD + CE} [grad_weight]
  190. This corresponds to the below conv
  191. {A, B, C} [inp]
  192. {D, E} [weight (grad_out)]
  193. => (conv)
  194. {AD + BE, BD + CE} [out (grad_weight)]
  195. # grad_weight of transposed conv as conv(grad_out, inp)
  196. As a reminder, the terms of the output of a transposed conv are:
  197. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  198. => {D, E, F, G} [grad_out]
  199. Manually computing the gradient for the weights, we see it's
  200. {AD + BE + CF, AE + BF + CG} [grad_weight]
  201. This corresponds to the below conv
  202. {D, E, F, G} [inp (grad_out)]
  203. {A, B, C} [weight (inp)]
  204. => (conv)
  205. {AD + BE + CF, AE + BF + CG} [out (grad_weight)]
  206. For the full backwards formula, there are also some details involving
  207. transpose of the batch/channel dimensions and groups, but I skip those for
  208. the sake of brevity (and they're pretty similar to matmul backwards)
  209. Check [conv backwards decomposition as conv forwards]
  210. """
  211. # grad_inp as conv_transpose(grad_out, weight)
  212. if output_mask[0]:
  213. grad_input_shape = get_shape(out_shape[0])
  214. flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
  215. if output_mask[1]:
  216. grad_weight_shape = get_shape(out_shape[1])
  217. if transposed:
  218. # grad_weight of transposed conv as conv(grad_out, inp)
  219. flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
  220. else:
  221. # grad_weight as conv(inp, grad_out)
  222. flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)
  223. return flop_count
  224. def sdpa_flop_count(query_shape, key_shape, value_shape):
  225. """
  226. Count flops for self-attention.
  227. NB: We can assume that value_shape == key_shape
  228. """
  229. b, h, s_q, d_q = query_shape
  230. _b2, _h2, s_k, _d2 = key_shape
  231. _b3, _h3, _s3, d_v = value_shape
  232. if not b == _b2 == _b3 or not h == _h2 == _h3 or not d_q == _d2 or not s_k == _s3 or not d_q == _d2:
  233. raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")
  234. total_flops = 0
  235. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  236. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  237. # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
  238. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
  239. return total_flops
  240. @register_flop_formula([aten._scaled_dot_product_efficient_attention,
  241. aten._scaled_dot_product_flash_attention,
  242. aten._scaled_dot_product_cudnn_attention])
  243. def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  244. """Count flops for self-attention."""
  245. # NB: We aren't accounting for causal attention here
  246. return sdpa_flop_count(query_shape, key_shape, value_shape)
  247. def _offsets_to_lengths(offsets, max_len):
  248. """
  249. If the offsets tensor is fake, then we don't know the actual lengths.
  250. In that case, we can just assume the worst case; each batch has max length.
  251. """
  252. from torch._subclasses.fake_tensor import FakeTensor
  253. from torch._subclasses.functional_tensor import FunctionalTensor
  254. if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
  255. return offsets.diff().tolist()
  256. return [max_len] * (offsets.size(0) - 1)
  257. def _unpack_flash_attention_nested_shapes(
  258. *,
  259. query,
  260. key,
  261. value,
  262. grad_out=None,
  263. cum_seq_q,
  264. cum_seq_k,
  265. max_q,
  266. max_k,
  267. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
  268. """
  269. Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
  270. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  271. each batch element.
  272. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  273. """
  274. if cum_seq_q is not None:
  275. # This means we should be dealing with a Nested Jagged Tensor query.
  276. # The inputs will have shape (sum(sequence len), heads, dimension)
  277. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  278. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  279. # So the flops calculation in this case is an overestimate of the actual flops.
  280. if len(key.shape) != 3:
  281. raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
  282. if len(value.shape) != 3:
  283. raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
  284. if grad_out is not None and grad_out.shape != query.shape:
  285. raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
  286. _, h_q, d_q = query.shape
  287. _, h_k, d_k = key.shape
  288. _, h_v, d_v = value.shape
  289. if cum_seq_q is None:
  290. raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
  291. if cum_seq_k is None:
  292. raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
  293. if cum_seq_q.shape != cum_seq_k.shape:
  294. raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
  295. seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
  296. seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
  297. for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True):
  298. new_query_shape = (1, h_q, seq_q_len, d_q)
  299. new_key_shape = (1, h_k, seq_k_len, d_k)
  300. new_value_shape = (1, h_v, seq_k_len, d_v)
  301. new_grad_out_shape = new_query_shape if grad_out is not None else None
  302. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  303. return
  304. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  305. def _unpack_efficient_attention_nested_shapes(
  306. *,
  307. query,
  308. key,
  309. value,
  310. grad_out=None,
  311. cu_seqlens_q,
  312. cu_seqlens_k,
  313. max_seqlen_q,
  314. max_seqlen_k,
  315. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
  316. """
  317. Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for
  318. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  319. each batch element.
  320. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  321. """
  322. if cu_seqlens_q is not None:
  323. # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
  324. #
  325. # This means we should be dealing with a Nested Jagged Tensor query.
  326. # The inputs will have shape (sum(sequence len), heads, dimension)
  327. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  328. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  329. # So the flops calculation in this case is an overestimate of the actual flops.
  330. if len(key.shape) != 4:
  331. raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
  332. if len(value.shape) != 4:
  333. raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
  334. if grad_out is not None and grad_out.shape != query.shape:
  335. raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
  336. _, _, h_q, d_q = query.shape
  337. _, _, h_k, d_k = key.shape
  338. _, _, h_v, d_v = value.shape
  339. if cu_seqlens_q is None:
  340. raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
  341. if cu_seqlens_k is None:
  342. raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
  343. if cu_seqlens_q.shape != cu_seqlens_k.shape:
  344. raise AssertionError("_unpack_efficient_attention_nested_shapes: "
  345. "cu_seqlens_q and cu_seqlens_k must have the same shape")
  346. seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
  347. seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
  348. for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True):
  349. new_query_shape = (1, h_q, len_q, d_q)
  350. new_key_shape = (1, h_k, len_k, d_k)
  351. new_value_shape = (1, h_v, len_k, d_v)
  352. new_grad_out_shape = new_query_shape if grad_out is not None else None
  353. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  354. return
  355. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  356. @register_flop_formula(aten._flash_attention_forward, get_raw=True)
  357. def _flash_attention_forward_flop(
  358. query,
  359. key,
  360. value,
  361. cum_seq_q,
  362. cum_seq_k,
  363. max_q,
  364. max_k,
  365. *args,
  366. out_shape=None,
  367. **kwargs
  368. ) -> int:
  369. """Count flops for self-attention."""
  370. # NB: We aren't accounting for causal attention here
  371. # in case this is a nested tensor, we unpack the individual batch elements
  372. # and then sum the flops per batch element
  373. sizes = _unpack_flash_attention_nested_shapes(
  374. query=query,
  375. key=key,
  376. value=value,
  377. cum_seq_q=cum_seq_q,
  378. cum_seq_k=cum_seq_k,
  379. max_q=max_q,
  380. max_k=max_k,
  381. )
  382. return sum(
  383. sdpa_flop_count(query_shape, key_shape, value_shape)
  384. for query_shape, key_shape, value_shape, _ in sizes
  385. )
  386. @register_flop_formula(aten._efficient_attention_forward, get_raw=True)
  387. def _efficient_attention_forward_flop(
  388. query,
  389. key,
  390. value,
  391. bias,
  392. cu_seqlens_q,
  393. cu_seqlens_k,
  394. max_seqlen_q,
  395. max_seqlen_k,
  396. *args,
  397. **kwargs
  398. ) -> int:
  399. """Count flops for self-attention."""
  400. # NB: We aren't accounting for causal attention here
  401. # in case this is a nested tensor, we unpack the individual batch elements
  402. # and then sum the flops per batch element
  403. sizes = _unpack_efficient_attention_nested_shapes(
  404. query=query,
  405. key=key,
  406. value=value,
  407. cu_seqlens_q=cu_seqlens_q,
  408. cu_seqlens_k=cu_seqlens_k,
  409. max_seqlen_q=max_seqlen_q,
  410. max_seqlen_k=max_seqlen_k,
  411. )
  412. return sum(
  413. sdpa_flop_count(query_shape, key_shape, value_shape)
  414. for query_shape, key_shape, value_shape, _ in sizes
  415. )
  416. def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
  417. total_flops = 0
  418. b, h, s_q, d_q = query_shape
  419. _b2, _h2, s_k, _d2 = key_shape
  420. _b3, _h3, _s3, d_v = value_shape
  421. _b4, _h4, _s4, _d4 = grad_out_shape
  422. if not b == _b2 == _b3 == _b4 or not h == _h2 == _h3 == _h4 or not d_q == _d2:
  423. raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
  424. if not d_v == _d4 or not s_k == _s3 or not s_q == _s4:
  425. raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")
  426. total_flops = 0
  427. # Step 1: We recompute the scores matrix.
  428. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  429. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  430. # Step 2: We propagate the gradients through the score @ v operation.
  431. # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
  432. total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
  433. # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
  434. total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
  435. # Step 3: We propagate th gradients through the k @ v operation
  436. # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
  437. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
  438. # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
  439. total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
  440. return total_flops
  441. @register_flop_formula([aten._scaled_dot_product_efficient_attention_backward,
  442. aten._scaled_dot_product_flash_attention_backward,
  443. aten._scaled_dot_product_cudnn_attention_backward])
  444. def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  445. """Count flops for self-attention backward."""
  446. return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  447. @register_flop_formula(aten._flash_attention_backward, get_raw=True)
  448. def _flash_attention_backward_flop(
  449. grad_out,
  450. query,
  451. key,
  452. value,
  453. out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper
  454. logsumexp,
  455. cum_seq_q,
  456. cum_seq_k,
  457. max_q,
  458. max_k,
  459. *args,
  460. **kwargs,
  461. ) -> int:
  462. # in case this is a nested tensor, we unpack the individual batch elements
  463. # and then sum the flops per batch element
  464. shapes = _unpack_flash_attention_nested_shapes(
  465. query=query,
  466. key=key,
  467. value=value,
  468. grad_out=grad_out,
  469. cum_seq_q=cum_seq_q,
  470. cum_seq_k=cum_seq_k,
  471. max_q=max_q,
  472. max_k=max_k,
  473. )
  474. return sum(
  475. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  476. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  477. )
  478. @register_flop_formula(aten._efficient_attention_backward, get_raw=True)
  479. def _efficient_attention_backward_flop(
  480. grad_out,
  481. query,
  482. key,
  483. value,
  484. bias,
  485. out, # named _out to avoid kwarg collision with out created in wrapper
  486. cu_seqlens_q,
  487. cu_seqlens_k,
  488. max_seqlen_q,
  489. max_seqlen_k,
  490. *args,
  491. **kwargs,
  492. ) -> int:
  493. # in case this is a nested tensor, we unpack the individual batch elements
  494. # and then sum the flops per batch element
  495. shapes = _unpack_efficient_attention_nested_shapes(
  496. query=query,
  497. key=key,
  498. value=value,
  499. grad_out=grad_out,
  500. cu_seqlens_q=cu_seqlens_q,
  501. cu_seqlens_k=cu_seqlens_k,
  502. max_seqlen_q=max_seqlen_q,
  503. max_seqlen_k=max_seqlen_k,
  504. )
  505. return sum(
  506. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  507. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  508. )
  509. flop_registry = {
  510. aten.mm: mm_flop,
  511. aten.addmm: addmm_flop,
  512. aten.bmm: bmm_flop,
  513. aten.baddbmm: baddbmm_flop,
  514. aten._scaled_mm: _scaled_mm_flop,
  515. aten.convolution: conv_flop,
  516. aten._convolution: conv_flop,
  517. aten.cudnn_convolution: conv_flop,
  518. aten.convolution_overrideable: conv_flop,
  519. aten._slow_conv2d_forward: conv_flop,
  520. aten.convolution_backward: conv_backward_flop,
  521. aten._scaled_dot_product_efficient_attention: sdpa_flop,
  522. aten._scaled_dot_product_flash_attention: sdpa_flop,
  523. aten._scaled_dot_product_cudnn_attention: sdpa_flop,
  524. aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
  525. aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
  526. aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop,
  527. aten._flash_attention_forward: _flash_attention_forward_flop,
  528. aten._efficient_attention_forward: _efficient_attention_forward_flop,
  529. aten._flash_attention_backward: _flash_attention_backward_flop,
  530. aten._efficient_attention_backward: _efficient_attention_backward_flop,
  531. }
  532. def normalize_tuple(x):
  533. if not isinstance(x, tuple):
  534. return (x,)
  535. return x
  536. # Define the suffixes for different orders of magnitude
  537. suffixes = ["", "K", "M", "B", "T"]
  538. # Thanks BingChat!
  539. def get_suffix_str(number):
  540. # Find the index of the appropriate suffix based on the number of digits
  541. # with some additional overflow.
  542. # i.e. 1.01B should be displayed as 1001M, not 1.001B
  543. index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
  544. return suffixes[index]
  545. def convert_num_with_suffix(number, suffix):
  546. index = suffixes.index(suffix)
  547. # Divide the number by 1000^index and format it to two decimal places
  548. value = f"{number / 1000 ** index:.3f}"
  549. # Return the value and the suffix as a string
  550. return value + suffixes[index]
  551. def convert_to_percent_str(num, denom) -> str:
  552. if denom == 0:
  553. return "0%"
  554. return f"{num / denom:.2%}"
  555. def _pytreeify_preserve_structure(f):
  556. @wraps(f)
  557. def nf(args):
  558. flat_args, spec = tree_flatten(args)
  559. out = f(*flat_args)
  560. return tree_unflatten(out, spec)
  561. return nf
  562. class FlopCounterMode:
  563. """
  564. ``FlopCounterMode`` is a context manager that counts the number of flops within its context.
  565. It does this using a ``TorchDispatchMode``.
  566. It also supports hierarchical output by passing a module (or list of
  567. modules) to FlopCounterMode on construction. If you do not need hierarchical
  568. output, you do not need to use it with a module.
  569. Example usage
  570. .. code-block:: python
  571. mod = ...
  572. with FlopCounterMode(mod) as flop_counter:
  573. mod.sum().backward()
  574. """
  575. def __init__(
  576. self,
  577. mods: torch.nn.Module | list[torch.nn.Module] | None = None,
  578. depth: int = 2,
  579. display: bool = True,
  580. custom_mapping: dict[Any, Any] | None = None) -> None:
  581. super().__init__()
  582. self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
  583. self.depth = depth
  584. self.display = display
  585. self.mode: _FlopCounterMode | None = None
  586. if custom_mapping is None:
  587. custom_mapping = {}
  588. if mods is not None:
  589. warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
  590. self.flop_registry = {
  591. **flop_registry,
  592. **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
  593. }
  594. self.mod_tracker = ModuleTracker()
  595. def get_total_flops(self) -> int:
  596. return sum(self.flop_counts['Global'].values())
  597. def get_flop_counts(self) -> dict[str, dict[Any, int]]:
  598. """Return the flop counts as a dictionary of dictionaries.
  599. The outer
  600. dictionary is keyed by module name, and the inner dictionary is keyed by
  601. operation name.
  602. Returns:
  603. Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
  604. """
  605. return {k: dict(v) for k, v in self.flop_counts.items()}
  606. def get_table(self, depth=None):
  607. if depth is None:
  608. depth = self.depth
  609. if depth is None:
  610. depth = 999999
  611. import tabulate
  612. tabulate.PRESERVE_WHITESPACE = True
  613. header = ["Module", "FLOP", "% Total"]
  614. values = []
  615. global_flops = self.get_total_flops()
  616. global_suffix = get_suffix_str(global_flops)
  617. is_global_subsumed = False
  618. def process_mod(mod_name, depth):
  619. nonlocal is_global_subsumed
  620. total_flops = sum(self.flop_counts[mod_name].values())
  621. is_global_subsumed |= total_flops >= global_flops
  622. padding = " " * depth
  623. values = []
  624. values.append([
  625. padding + mod_name,
  626. convert_num_with_suffix(total_flops, global_suffix),
  627. convert_to_percent_str(total_flops, global_flops)
  628. ])
  629. for k, v in self.flop_counts[mod_name].items():
  630. values.append([
  631. padding + " - " + str(k),
  632. convert_num_with_suffix(v, global_suffix),
  633. convert_to_percent_str(v, global_flops)
  634. ])
  635. return values
  636. for mod in sorted(self.flop_counts.keys()):
  637. if mod == 'Global':
  638. continue
  639. mod_depth = mod.count(".") + 1
  640. if mod_depth > depth:
  641. continue
  642. cur_values = process_mod(mod, mod_depth - 1)
  643. values.extend(cur_values)
  644. # We do a bit of messing around here to only output the "Global" value
  645. # if there are any FLOPs in there that aren't already fully contained by
  646. # a module.
  647. if 'Global' in self.flop_counts and not is_global_subsumed:
  648. for value in values:
  649. value[0] = " " + value[0]
  650. values = process_mod('Global', 0) + values
  651. if len(values) == 0:
  652. values = [["Global", "0", "0%"]]
  653. return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
  654. # NB: This context manager is NOT reentrant
  655. def __enter__(self):
  656. self.flop_counts.clear()
  657. self.mod_tracker.__enter__()
  658. self.mode = _FlopCounterMode(self)
  659. self.mode.__enter__()
  660. return self
  661. def __exit__(self, *args):
  662. if self.mode is None:
  663. raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
  664. b = self.mode.__exit__(*args)
  665. self.mode = None # break cycles
  666. self.mod_tracker.__exit__()
  667. if self.display:
  668. print(self.get_table(self.depth))
  669. return b
  670. def _count_flops(self, func_packet, out, args, kwargs):
  671. if func_packet in self.flop_registry:
  672. flop_count_func = self.flop_registry[func_packet]
  673. flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator]
  674. for par in set(self.mod_tracker.parents):
  675. self.flop_counts[par][func_packet] += flop_count
  676. return out
  677. class _FlopCounterMode(TorchDispatchMode):
  678. supports_higher_order_operators = True
  679. def __init__(self, counter: FlopCounterMode) -> None:
  680. self.counter = counter
  681. def _execute_with_isolated_flop_counting(self, branch_fn, operands):
  682. """Execute a branch function and capture its FLOP counts without
  683. affecting self.counter.flop_counts
  684. Args:
  685. branch_fn: The branch function to execute
  686. operands: Arguments to pass to the branch function
  687. Returns:
  688. Tuple of (result, flop_counts) where result is the branch output
  689. and flop_counts is a copy of the FLOP counts after execution
  690. """
  691. import copy
  692. checkpointed_flop_counts = copy.copy(self.counter.flop_counts)
  693. with self:
  694. result = branch_fn(*operands)
  695. flop_counts = copy.copy(self.counter.flop_counts)
  696. self.counter.flop_counts = checkpointed_flop_counts
  697. return result, flop_counts
  698. def _handle_higher_order_ops(self, func, types, args, kwargs):
  699. if func is not torch.ops.higher_order.cond:
  700. return NotImplemented
  701. # The flop counter for cond counts the upper bound of flops.
  702. # For example, if a matmul is executed 2 times in true branch
  703. # but only 1 time in the false branch, the flop counter will
  704. # record the larger number of flops, i.e. 2 times.
  705. if func is torch.ops.higher_order.cond:
  706. pred, true_branch, false_branch, operands = args
  707. # Step 1: Count flops for true branch and false branch separately
  708. true_out, true_flop_counts = self._execute_with_isolated_flop_counting(
  709. true_branch, operands
  710. )
  711. if true_out is NotImplemented:
  712. return NotImplemented
  713. false_out, false_flop_counts = self._execute_with_isolated_flop_counting(
  714. false_branch, operands
  715. )
  716. if false_out is NotImplemented:
  717. return NotImplemented
  718. # Step 2: merge flop counts
  719. all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys())
  720. merged_flop_counts = {}
  721. for outer_key in all_mod_keys:
  722. true_func_counts = true_flop_counts[outer_key]
  723. false_func_counts = false_flop_counts[outer_key]
  724. merged_func_counts = {}
  725. all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys())
  726. for func_key in all_func_keys:
  727. true_val = true_func_counts.get(func_key, 0)
  728. false_val = false_func_counts.get(func_key, 0)
  729. merged_func_counts[func_key] = max(true_val, false_val)
  730. merged_flop_counts[outer_key] = merged_func_counts
  731. # Step 3: update the counter with merged counts
  732. for outer_key, inner_dict in merged_flop_counts.items():
  733. self.counter.flop_counts[outer_key].update(inner_dict)
  734. # It doesn't matter which one we return since true_fn and false_fn return
  735. # output with the same structure.
  736. return true_out
  737. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  738. kwargs = kwargs if kwargs else {}
  739. # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
  740. if func in {torch.ops.aten.sym_is_contiguous.default,
  741. torch.ops.aten.is_contiguous.default,
  742. torch.ops.aten.is_contiguous.memory_format,
  743. torch.ops.aten.is_strides_like_format.default,
  744. torch.ops.aten.is_non_overlapping_and_dense.default,
  745. torch.ops.aten.size.default,
  746. torch.ops.aten.sym_size.default,
  747. torch.ops.aten.stride.default,
  748. torch.ops.aten.sym_stride.default,
  749. torch.ops.aten.storage_offset.default,
  750. torch.ops.aten.sym_storage_offset.default,
  751. torch.ops.aten.numel.default,
  752. torch.ops.aten.sym_numel.default,
  753. torch.ops.aten.dim.default,
  754. torch.ops.prim.layout.default}:
  755. return NotImplemented
  756. if isinstance(func, torch._ops.HigherOrderOperator):
  757. return self._handle_higher_order_ops(func, types, args, kwargs)
  758. # If we don't have func in flop_registry, see if it can decompose
  759. if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
  760. with self:
  761. r = func.decompose(*args, **kwargs)
  762. if r is not NotImplemented:
  763. return r
  764. # no further decomposition; execute & count flops
  765. out = func(*args, **kwargs)
  766. return self.counter._count_flops(func._overloadpacket, out, args, kwargs)