flop_counter.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  4. from .module_tracker import ModuleTracker
  5. from typing import Any, Optional, Union, TypeVar, Callable
  6. from collections.abc import Iterator
  7. from typing_extensions import ParamSpec
  8. from collections import defaultdict
  9. from torch.utils._python_dispatch import TorchDispatchMode
  10. from math import prod
  11. from functools import wraps
  12. import warnings
  13. __all__ = ["FlopCounterMode", "register_flop_formula"]
  14. _T = TypeVar("_T")
  15. _P = ParamSpec("_P")
  16. aten = torch.ops.aten
  17. def get_shape(i):
  18. if isinstance(i, torch.Tensor):
  19. return i.shape
  20. return i
  21. flop_registry: dict[Any, Any] = {}
  22. def shape_wrapper(f):
  23. @wraps(f)
  24. def nf(*args, out_val=None, **kwargs):
  25. args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
  26. return f(*args, out_shape=out_shape, **kwargs)
  27. return nf
  28. def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  29. def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
  30. if not get_raw:
  31. flop_formula = shape_wrapper(flop_formula)
  32. def register(target):
  33. if not isinstance(target, torch._ops.OpOverloadPacket):
  34. raise ValueError(
  35. f"register_flop_formula(targets): expected each target to be "
  36. f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
  37. f"{target} which is of type {type(target)}")
  38. if target in flop_registry:
  39. raise RuntimeError(f"duplicate registrations for {target}")
  40. flop_registry[target] = flop_formula
  41. # To handle allowing multiple aten_ops at once
  42. torch.utils._pytree.tree_map_(register, targets)
  43. return flop_formula
  44. return register_fun
  45. @register_flop_formula(aten.mm)
  46. def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
  47. """Count flops for matmul."""
  48. # Inputs should be a list of length 2.
  49. # Inputs contains the shapes of two matrices.
  50. m, k = a_shape
  51. k2, n = b_shape
  52. assert k == k2
  53. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  54. return m * n * 2 * k
  55. @register_flop_formula(aten.addmm)
  56. def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  57. """Count flops for addmm."""
  58. return mm_flop(a_shape, b_shape)
  59. @register_flop_formula(aten.bmm)
  60. def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
  61. """Count flops for the bmm operation."""
  62. # Inputs should be a list of length 2.
  63. # Inputs contains the shapes of two tensor.
  64. b, m, k = a_shape
  65. b2, k2, n = b_shape
  66. assert b == b2
  67. assert k == k2
  68. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  69. flop = b * m * n * 2 * k
  70. return flop
  71. @register_flop_formula(aten.baddbmm)
  72. def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  73. """Count flops for the baddbmm operation."""
  74. # Inputs should be a list of length 3.
  75. # Inputs contains the shapes of three tensors.
  76. return bmm_flop(a_shape, b_shape)
  77. @register_flop_formula(aten._scaled_mm)
  78. def _scaled_mm_flop(
  79. a_shape,
  80. b_shape,
  81. scale_a_shape,
  82. scale_b_shape,
  83. bias_shape=None,
  84. scale_result_shape=None,
  85. out_dtype=None,
  86. use_fast_accum=False,
  87. out_shape=None,
  88. **kwargs,
  89. ) -> int:
  90. """Count flops for _scaled_mm."""
  91. return mm_flop(a_shape, b_shape)
  92. def conv_flop_count(
  93. x_shape: list[int],
  94. w_shape: list[int],
  95. out_shape: list[int],
  96. transposed: bool = False,
  97. ) -> int:
  98. """Count flops for convolution.
  99. Note only multiplication is
  100. counted. Computation for bias are ignored.
  101. Flops for a transposed convolution are calculated as
  102. flops = (x_shape[2:] * prod(w_shape) * batch_size).
  103. Args:
  104. x_shape (list(int)): The input shape before convolution.
  105. w_shape (list(int)): The filter shape.
  106. out_shape (list(int)): The output shape after convolution.
  107. transposed (bool): is the convolution transposed
  108. Returns:
  109. int: the number of flops
  110. """
  111. batch_size = x_shape[0]
  112. conv_shape = (x_shape if transposed else out_shape)[2:]
  113. c_out, c_in, *filter_size = w_shape
  114. """
  115. General idea here is that for a regular conv, for each point in the output
  116. spatial dimension we convolve the filter with something (hence
  117. `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
  118. 1. batch_size, 2. the cross product of input and weight channels.
  119. For the transpose, it's not each point in the *output* spatial dimension but
  120. each point in the *input* spatial dimension.
  121. """
  122. # NB(chilli): I don't think this properly accounts for padding :think:
  123. # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
  124. flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
  125. return flop
  126. @register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
  127. def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
  128. """Count flops for convolution."""
  129. return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
  130. @register_flop_formula(aten.convolution_backward)
  131. def conv_backward_flop(
  132. grad_out_shape,
  133. x_shape,
  134. w_shape,
  135. _bias,
  136. _stride,
  137. _padding,
  138. _dilation,
  139. transposed,
  140. _output_padding,
  141. _groups,
  142. output_mask,
  143. out_shape) -> int:
  144. def t(shape):
  145. return [shape[1], shape[0]] + list(shape[2:])
  146. flop_count = 0
  147. """
  148. Let's say we have a regular 1D conv
  149. {A, B, C} [inp]
  150. {i, j} [weight]
  151. => (conv)
  152. {Ai + Bj, Bi + Cj} [out]
  153. And as a reminder, the transposed conv of the above is
  154. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  155. For the backwards of conv, we now have
  156. {D, E} [grad_out]
  157. {A, B, C} [inp]
  158. {i, j} [weight]
  159. # grad_inp as conv_transpose(grad_out, weight)
  160. Let's first compute grad_inp. To do so, we can simply look at all the
  161. multiplications that each element of inp is involved in. For example, A is
  162. only involved in the first element of the output (and thus only depends upon
  163. D in grad_out), and C is only involved in the last element of the output
  164. (and thus only depends upon E in grad_out)
  165. {Di, Dj + Ei, Ej} [grad_inp]
  166. Note that this corresponds to the below conv_transpose. This gives us the
  167. output_mask[0] branch, which is grad_inp.
  168. {D, E} [inp (grad_out)]
  169. {i, j} [weight]
  170. => (conv_transpose)
  171. {Di, Dj + Ei, Ej} [out (grad_inp)]
  172. I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
  173. weight) as an exercise for the reader.
  174. # grad_weight as conv(inp, grad_out)
  175. To compute grad_weight, we again look at the terms in the output, which as
  176. a reminder is:
  177. => {Ai + Bj, Bi + Cj} [out]
  178. => {D, E} [grad_out]
  179. If we manually compute the gradient for the weights, we see it's
  180. {AD + BE, BD + CE} [grad_weight]
  181. This corresponds to the below conv
  182. {A, B, C} [inp]
  183. {D, E} [weight (grad_out)]
  184. => (conv)
  185. {AD + BE, BD + CE} [out (grad_weight)]
  186. # grad_weight of transposed conv as conv(grad_out, inp)
  187. As a reminder, the terms of the output of a transposed conv are:
  188. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  189. => {D, E, F, G} [grad_out]
  190. Manually computing the gradient for the weights, we see it's
  191. {AD + BE + CF, AE + BF + CG} [grad_weight]
  192. This corresponds to the below conv
  193. {D, E, F, G} [inp (grad_out)]
  194. {A, B, C} [weight (inp)]
  195. => (conv)
  196. {AD + BE + CF, AE + BF + CG} [out (grad_weight)]
  197. For the full backwards formula, there are also some details involving
  198. transpose of the batch/channel dimensions and groups, but I skip those for
  199. the sake of brevity (and they're pretty similar to matmul backwards)
  200. Check [conv backwards decomposition as conv forwards]
  201. """
  202. # grad_inp as conv_transpose(grad_out, weight)
  203. if output_mask[0]:
  204. grad_input_shape = get_shape(out_shape[0])
  205. flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
  206. if output_mask[1]:
  207. grad_weight_shape = get_shape(out_shape[1])
  208. if transposed:
  209. # grad_weight of transposed conv as conv(grad_out, inp)
  210. flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
  211. else:
  212. # grad_weight as conv(inp, grad_out)
  213. flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)
  214. return flop_count
  215. def sdpa_flop_count(query_shape, key_shape, value_shape):
  216. """
  217. Count flops for self-attention.
  218. NB: We can assume that value_shape == key_shape
  219. """
  220. b, h, s_q, d_q = query_shape
  221. _b2, _h2, s_k, _d2 = key_shape
  222. _b3, _h3, _s3, d_v = value_shape
  223. assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
  224. total_flops = 0
  225. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  226. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  227. # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
  228. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
  229. return total_flops
  230. @register_flop_formula([aten._scaled_dot_product_efficient_attention,
  231. aten._scaled_dot_product_flash_attention,
  232. aten._scaled_dot_product_cudnn_attention])
  233. def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  234. """Count flops for self-attention."""
  235. # NB: We aren't accounting for causal attention here
  236. return sdpa_flop_count(query_shape, key_shape, value_shape)
  237. def _offsets_to_lengths(offsets, max_len):
  238. """
  239. If the offsets tensor is fake, then we don't know the actual lengths.
  240. In that case, we can just assume the worst case; each batch has max length.
  241. """
  242. from torch._subclasses.fake_tensor import FakeTensor
  243. from torch._subclasses.functional_tensor import FunctionalTensor
  244. if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
  245. return offsets.diff().tolist()
  246. return [max_len] * (offsets.size(0) - 1)
  247. def _unpack_flash_attention_nested_shapes(
  248. *,
  249. query,
  250. key,
  251. value,
  252. grad_out=None,
  253. cum_seq_q,
  254. cum_seq_k,
  255. max_q,
  256. max_k,
  257. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]:
  258. """
  259. Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
  260. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  261. each batch element.
  262. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  263. """
  264. if cum_seq_q is not None:
  265. # This means we should be dealing with a Nested Jagged Tensor query.
  266. # The inputs will have shape (sum(sequence len), heads, dimension)
  267. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  268. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  269. # So the flops calculation in this case is an overestimate of the actual flops.
  270. assert len(key.shape) == 3
  271. assert len(value.shape) == 3
  272. assert grad_out is None or grad_out.shape == query.shape
  273. _, h_q, d_q = query.shape
  274. _, h_k, d_k = key.shape
  275. _, h_v, d_v = value.shape
  276. assert cum_seq_q is not None
  277. assert cum_seq_k is not None
  278. assert cum_seq_q.shape == cum_seq_k.shape
  279. seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
  280. seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
  281. for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
  282. new_query_shape = (1, h_q, seq_q_len, d_q)
  283. new_key_shape = (1, h_k, seq_k_len, d_k)
  284. new_value_shape = (1, h_v, seq_k_len, d_v)
  285. new_grad_out_shape = new_query_shape if grad_out is not None else None
  286. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  287. return
  288. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  289. def _unpack_efficient_attention_nested_shapes(
  290. *,
  291. query,
  292. key,
  293. value,
  294. grad_out=None,
  295. cu_seqlens_q,
  296. cu_seqlens_k,
  297. max_seqlen_q,
  298. max_seqlen_k,
  299. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]:
  300. """
  301. Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for
  302. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  303. each batch element.
  304. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  305. """
  306. if cu_seqlens_q is not None:
  307. # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
  308. #
  309. # This means we should be dealing with a Nested Jagged Tensor query.
  310. # The inputs will have shape (sum(sequence len), heads, dimension)
  311. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  312. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  313. # So the flops calculation in this case is an overestimate of the actual flops.
  314. assert len(key.shape) == 4
  315. assert len(value.shape) == 4
  316. assert grad_out is None or grad_out.shape == query.shape
  317. _, _, h_q, d_q = query.shape
  318. _, _, h_k, d_k = key.shape
  319. _, _, h_v, d_v = value.shape
  320. assert cu_seqlens_q is not None
  321. assert cu_seqlens_k is not None
  322. assert cu_seqlens_q.shape == cu_seqlens_k.shape
  323. seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
  324. seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
  325. for len_q, len_k in zip(seqlens_q, seqlens_k):
  326. new_query_shape = (1, h_q, len_q, d_q)
  327. new_key_shape = (1, h_k, len_k, d_k)
  328. new_value_shape = (1, h_v, len_k, d_v)
  329. new_grad_out_shape = new_query_shape if grad_out is not None else None
  330. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  331. return
  332. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  333. @register_flop_formula(aten._flash_attention_forward, get_raw=True)
  334. def _flash_attention_forward_flop(
  335. query,
  336. key,
  337. value,
  338. cum_seq_q,
  339. cum_seq_k,
  340. max_q,
  341. max_k,
  342. *args,
  343. out_shape=None,
  344. **kwargs
  345. ) -> int:
  346. """Count flops for self-attention."""
  347. # NB: We aren't accounting for causal attention here
  348. # in case this is a nested tensor, we unpack the individual batch elements
  349. # and then sum the flops per batch element
  350. sizes = _unpack_flash_attention_nested_shapes(
  351. query=query,
  352. key=key,
  353. value=value,
  354. cum_seq_q=cum_seq_q,
  355. cum_seq_k=cum_seq_k,
  356. max_q=max_q,
  357. max_k=max_k,
  358. )
  359. return sum(
  360. sdpa_flop_count(query_shape, key_shape, value_shape)
  361. for query_shape, key_shape, value_shape, _ in sizes
  362. )
  363. @register_flop_formula(aten._efficient_attention_forward, get_raw=True)
  364. def _efficient_attention_forward_flop(
  365. query,
  366. key,
  367. value,
  368. bias,
  369. cu_seqlens_q,
  370. cu_seqlens_k,
  371. max_seqlen_q,
  372. max_seqlen_k,
  373. *args,
  374. **kwargs
  375. ) -> int:
  376. """Count flops for self-attention."""
  377. # NB: We aren't accounting for causal attention here
  378. # in case this is a nested tensor, we unpack the individual batch elements
  379. # and then sum the flops per batch element
  380. sizes = _unpack_efficient_attention_nested_shapes(
  381. query=query,
  382. key=key,
  383. value=value,
  384. cu_seqlens_q=cu_seqlens_q,
  385. cu_seqlens_k=cu_seqlens_k,
  386. max_seqlen_q=max_seqlen_q,
  387. max_seqlen_k=max_seqlen_k,
  388. )
  389. return sum(
  390. sdpa_flop_count(query_shape, key_shape, value_shape)
  391. for query_shape, key_shape, value_shape, _ in sizes
  392. )
  393. def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
  394. total_flops = 0
  395. b, h, s_q, d_q = query_shape
  396. _b2, _h2, s_k, _d2 = key_shape
  397. _b3, _h3, _s3, d_v = value_shape
  398. _b4, _h4, _s4, _d4 = grad_out_shape
  399. assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
  400. assert d_v == _d4 and s_k == _s3 and s_q == _s4
  401. total_flops = 0
  402. # Step 1: We recompute the scores matrix.
  403. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  404. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  405. # Step 2: We propagate the gradients through the score @ v operation.
  406. # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
  407. total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
  408. # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
  409. total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
  410. # Step 3: We propagate th gradients through the k @ v operation
  411. # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
  412. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
  413. # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
  414. total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
  415. return total_flops
  416. @register_flop_formula([aten._scaled_dot_product_efficient_attention_backward,
  417. aten._scaled_dot_product_flash_attention_backward,
  418. aten._scaled_dot_product_cudnn_attention_backward])
  419. def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  420. """Count flops for self-attention backward."""
  421. return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  422. @register_flop_formula(aten._flash_attention_backward, get_raw=True)
  423. def _flash_attention_backward_flop(
  424. grad_out,
  425. query,
  426. key,
  427. value,
  428. out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper
  429. logsumexp,
  430. cum_seq_q,
  431. cum_seq_k,
  432. max_q,
  433. max_k,
  434. *args,
  435. **kwargs,
  436. ) -> int:
  437. # in case this is a nested tensor, we unpack the individual batch elements
  438. # and then sum the flops per batch element
  439. shapes = _unpack_flash_attention_nested_shapes(
  440. query=query,
  441. key=key,
  442. value=value,
  443. grad_out=grad_out,
  444. cum_seq_q=cum_seq_q,
  445. cum_seq_k=cum_seq_k,
  446. max_q=max_q,
  447. max_k=max_k,
  448. )
  449. return sum(
  450. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  451. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  452. )
  453. @register_flop_formula(aten._efficient_attention_backward, get_raw=True)
  454. def _efficient_attention_backward_flop(
  455. grad_out,
  456. query,
  457. key,
  458. value,
  459. bias,
  460. out, # named _out to avoid kwarg collision with out created in wrapper
  461. cu_seqlens_q,
  462. cu_seqlens_k,
  463. max_seqlen_q,
  464. max_seqlen_k,
  465. *args,
  466. **kwargs,
  467. ) -> int:
  468. # in case this is a nested tensor, we unpack the individual batch elements
  469. # and then sum the flops per batch element
  470. shapes = _unpack_efficient_attention_nested_shapes(
  471. query=query,
  472. key=key,
  473. value=value,
  474. grad_out=grad_out,
  475. cu_seqlens_q=cu_seqlens_q,
  476. cu_seqlens_k=cu_seqlens_k,
  477. max_seqlen_q=max_seqlen_q,
  478. max_seqlen_k=max_seqlen_k,
  479. )
  480. return sum(
  481. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  482. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  483. )
  484. flop_registry = {
  485. aten.mm: mm_flop,
  486. aten.addmm: addmm_flop,
  487. aten.bmm: bmm_flop,
  488. aten.baddbmm: baddbmm_flop,
  489. aten._scaled_mm: _scaled_mm_flop,
  490. aten.convolution: conv_flop,
  491. aten._convolution: conv_flop,
  492. aten.cudnn_convolution: conv_flop,
  493. aten._slow_conv2d_forward: conv_flop,
  494. aten.convolution_backward: conv_backward_flop,
  495. aten._scaled_dot_product_efficient_attention: sdpa_flop,
  496. aten._scaled_dot_product_flash_attention: sdpa_flop,
  497. aten._scaled_dot_product_cudnn_attention: sdpa_flop,
  498. aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
  499. aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
  500. aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop,
  501. aten._flash_attention_forward: _flash_attention_forward_flop,
  502. aten._efficient_attention_forward: _efficient_attention_forward_flop,
  503. aten._flash_attention_backward: _flash_attention_backward_flop,
  504. aten._efficient_attention_backward: _efficient_attention_backward_flop,
  505. }
  506. def normalize_tuple(x):
  507. if not isinstance(x, tuple):
  508. return (x,)
  509. return x
  510. # Define the suffixes for different orders of magnitude
  511. suffixes = ["", "K", "M", "B", "T"]
  512. # Thanks BingChat!
  513. def get_suffix_str(number):
  514. # Find the index of the appropriate suffix based on the number of digits
  515. # with some additional overflow.
  516. # i.e. 1.01B should be displayed as 1001M, not 1.001B
  517. index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
  518. return suffixes[index]
  519. def convert_num_with_suffix(number, suffix):
  520. index = suffixes.index(suffix)
  521. # Divide the number by 1000^index and format it to two decimal places
  522. value = f"{number / 1000 ** index:.3f}"
  523. # Return the value and the suffix as a string
  524. return value + suffixes[index]
  525. def convert_to_percent_str(num, denom):
  526. if denom == 0:
  527. return "0%"
  528. return f"{num / denom:.2%}"
  529. def _pytreeify_preserve_structure(f):
  530. @wraps(f)
  531. def nf(args):
  532. flat_args, spec = tree_flatten(args)
  533. out = f(*flat_args)
  534. return tree_unflatten(out, spec)
  535. return nf
  536. class FlopCounterMode:
  537. """
  538. ``FlopCounterMode`` is a context manager that counts the number of flops within its context.
  539. It does this using a ``TorchDispatchMode``.
  540. It also supports hierarchical output by passing a module (or list of
  541. modules) to FlopCounterMode on construction. If you do not need hierarchical
  542. output, you do not need to use it with a module.
  543. Example usage
  544. .. code-block:: python
  545. mod = ...
  546. with FlopCounterMode(mod) as flop_counter:
  547. mod.sum().backward()
  548. """
  549. def __init__(
  550. self,
  551. mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None,
  552. depth: int = 2,
  553. display: bool = True,
  554. custom_mapping: Optional[dict[Any, Any]] = None):
  555. super().__init__()
  556. self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
  557. self.depth = depth
  558. self.display = display
  559. self.mode: Optional[_FlopCounterMode] = None
  560. if custom_mapping is None:
  561. custom_mapping = {}
  562. if mods is not None:
  563. warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
  564. self.flop_registry = {
  565. **flop_registry,
  566. **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
  567. }
  568. self.mod_tracker = ModuleTracker()
  569. def get_total_flops(self) -> int:
  570. return sum(self.flop_counts['Global'].values())
  571. def get_flop_counts(self) -> dict[str, dict[Any, int]]:
  572. """Return the flop counts as a dictionary of dictionaries.
  573. The outer
  574. dictionary is keyed by module name, and the inner dictionary is keyed by
  575. operation name.
  576. Returns:
  577. Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
  578. """
  579. return {k: dict(v) for k, v in self.flop_counts.items()}
  580. def get_table(self, depth=None):
  581. if depth is None:
  582. depth = self.depth
  583. if depth is None:
  584. depth = 999999
  585. import tabulate
  586. tabulate.PRESERVE_WHITESPACE = True
  587. header = ["Module", "FLOP", "% Total"]
  588. values = []
  589. global_flops = self.get_total_flops()
  590. global_suffix = get_suffix_str(global_flops)
  591. is_global_subsumed = False
  592. def process_mod(mod_name, depth):
  593. nonlocal is_global_subsumed
  594. total_flops = sum(self.flop_counts[mod_name].values())
  595. is_global_subsumed |= total_flops >= global_flops
  596. padding = " " * depth
  597. values = []
  598. values.append([
  599. padding + mod_name,
  600. convert_num_with_suffix(total_flops, global_suffix),
  601. convert_to_percent_str(total_flops, global_flops)
  602. ])
  603. for k, v in self.flop_counts[mod_name].items():
  604. values.append([
  605. padding + " - " + str(k),
  606. convert_num_with_suffix(v, global_suffix),
  607. convert_to_percent_str(v, global_flops)
  608. ])
  609. return values
  610. for mod in sorted(self.flop_counts.keys()):
  611. if mod == 'Global':
  612. continue
  613. mod_depth = mod.count(".") + 1
  614. if mod_depth > depth:
  615. continue
  616. cur_values = process_mod(mod, mod_depth - 1)
  617. values.extend(cur_values)
  618. # We do a bit of messing around here to only output the "Global" value
  619. # if there are any FLOPs in there that aren't already fully contained by
  620. # a module.
  621. if 'Global' in self.flop_counts and not is_global_subsumed:
  622. for value in values:
  623. value[0] = " " + value[0]
  624. values = process_mod('Global', 0) + values
  625. if len(values) == 0:
  626. values = [["Global", "0", "0%"]]
  627. return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
  628. # NB: This context manager is NOT reentrant
  629. def __enter__(self):
  630. self.flop_counts.clear()
  631. self.mod_tracker.__enter__()
  632. self.mode = _FlopCounterMode(self)
  633. self.mode.__enter__()
  634. return self
  635. def __exit__(self, *args):
  636. assert self.mode is not None
  637. b = self.mode.__exit__(*args)
  638. self.mode = None # break cycles
  639. self.mod_tracker.__exit__()
  640. if self.display:
  641. print(self.get_table(self.depth))
  642. return b
  643. def _count_flops(self, func_packet, out, args, kwargs):
  644. if func_packet in self.flop_registry:
  645. flop_count_func = self.flop_registry[func_packet]
  646. flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator]
  647. for par in set(self.mod_tracker.parents):
  648. self.flop_counts[par][func_packet] += flop_count
  649. return out
  650. class _FlopCounterMode(TorchDispatchMode):
  651. supports_higher_order_operators = True
  652. def __init__(self, counter: FlopCounterMode):
  653. self.counter = counter
  654. def _execute_with_isolated_flop_counting(self, branch_fn, operands):
  655. """Execute a branch function and capture its FLOP counts without
  656. affecting self.counter.flop_counts
  657. Args:
  658. branch_fn: The branch function to execute
  659. operands: Arguments to pass to the branch function
  660. Returns:
  661. Tuple of (result, flop_counts) where result is the branch output
  662. and flop_counts is a copy of the FLOP counts after execution
  663. """
  664. import copy
  665. checkpointed_flop_counts = copy.copy(self.counter.flop_counts)
  666. with self:
  667. result = branch_fn(*operands)
  668. flop_counts = copy.copy(self.counter.flop_counts)
  669. self.counter.flop_counts = checkpointed_flop_counts
  670. return result, flop_counts
  671. def _handle_higher_order_ops(self, func, types, args, kwargs):
  672. if func not in {torch.ops.higher_order.cond, }:
  673. return NotImplemented
  674. # The flop counter for cond counts the upper bound of flops.
  675. # For example, if a matmul is executed 2 times in true branch
  676. # but only 1 time in the false branch, the flop counter will
  677. # record the larger number of flops, i.e. 2 times.
  678. if func is torch.ops.higher_order.cond:
  679. pred, true_branch, false_branch, operands = args
  680. # Step 1: Count flops for true branch and false branch separately
  681. true_out, true_flop_counts = self._execute_with_isolated_flop_counting(
  682. true_branch, operands
  683. )
  684. if true_out is NotImplemented:
  685. return NotImplemented
  686. false_out, false_flop_counts = self._execute_with_isolated_flop_counting(
  687. false_branch, operands
  688. )
  689. if false_out is NotImplemented:
  690. return NotImplemented
  691. # Step 2: merge flop counts
  692. all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys())
  693. merged_flop_counts = {}
  694. for outer_key in all_mod_keys:
  695. true_func_counts = true_flop_counts[outer_key]
  696. false_func_counts = false_flop_counts[outer_key]
  697. merged_func_counts = {}
  698. all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys())
  699. for func_key in all_func_keys:
  700. true_val = true_func_counts.get(func_key, 0)
  701. false_val = false_func_counts.get(func_key, 0)
  702. merged_func_counts[func_key] = max(true_val, false_val)
  703. merged_flop_counts[outer_key] = merged_func_counts
  704. # Step 3: update the counter with merged counts
  705. for outer_key, inner_dict in merged_flop_counts.items():
  706. self.counter.flop_counts[outer_key].update(inner_dict)
  707. # It doesn't matter which one we return since true_fn and false_fn return
  708. # output with the same structure.
  709. return true_out
  710. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  711. kwargs = kwargs if kwargs else {}
  712. # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
  713. if func in {torch.ops.aten.sym_is_contiguous.default,
  714. torch.ops.aten.is_contiguous.default,
  715. torch.ops.aten.is_contiguous.memory_format,
  716. torch.ops.aten.is_strides_like_format.default,
  717. torch.ops.aten.is_non_overlapping_and_dense.default,
  718. torch.ops.aten.size.default,
  719. torch.ops.aten.sym_size.default,
  720. torch.ops.aten.stride.default,
  721. torch.ops.aten.sym_stride.default,
  722. torch.ops.aten.storage_offset.default,
  723. torch.ops.aten.sym_storage_offset.default,
  724. torch.ops.aten.numel.default,
  725. torch.ops.aten.sym_numel.default,
  726. torch.ops.aten.dim.default,
  727. torch.ops.prim.layout.default}:
  728. return NotImplemented
  729. if isinstance(func, torch._ops.HigherOrderOperator):
  730. return self._handle_higher_order_ops(func, types, args, kwargs)
  731. # If we don't have func in flop_registry, see if it can decompose
  732. if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
  733. with self:
  734. r = func.decompose(*args, **kwargs)
  735. if r is not NotImplemented:
  736. return r
  737. # no further decomposition; execute & count flops
  738. out = func(*args, **kwargs)
  739. return self.counter._count_flops(func._overloadpacket, out, args, kwargs)