flops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
# Registry mapping an op type name to the function that computes its FLOPs.
# It is filled by the ``register_flops`` decorator and read by ``flops``.
_FLOPS_COMPUTE_FUNC_MAP = {}
  16. def prod(s):
  17. p = 1
  18. for v in s:
  19. p *= v
  20. return p
  21. def flops(op_type: str, input_shapes: dict, attrs: dict) -> int:
  22. """
  23. count FLOPs for operation.
  24. Args:
  25. op_type (str): the type of operation.
  26. input_shapes (dict): the shapes of inputs.
  27. attrs (dict): the attributes of the operation.
  28. Returns:
  29. the total FLOPs of the operation.
  30. """
  31. if op_type not in _FLOPS_COMPUTE_FUNC_MAP:
  32. return 0
  33. else:
  34. func = _FLOPS_COMPUTE_FUNC_MAP[op_type]
  35. try:
  36. flops = func(input_shapes, attrs)
  37. except Exception as e:
  38. return 0
  39. return flops
  40. def register_flops(op_type):
  41. """
  42. register flops computation function for operation.
  43. """
  44. def register(func):
  45. global _FLOPS_COMPUTE_FUNC_MAP
  46. _FLOPS_COMPUTE_FUNC_MAP[op_type] = func
  47. return func
  48. return register
  49. @register_flops("c_embedding")
  50. def _c_embedding_flops(input_shapes, attrs):
  51. """FLOPs computation for c_embedding op.
  52. For c_embedding(input):
  53. equation: flops = 0
  54. """
  55. return 0
  56. @register_flops("conv2d")
  57. def _conv2d_flops(input_shapes, attrs):
  58. """FLOPs computation for conv2d op.
  59. For conv2d(input,filter):
  60. active_elements = batch_size * numel(output)
  61. conv_flops = 2 * macs_per_position_conv * active_elements
  62. bias_flops = out_channels * active_elements
  63. equation: flops = conv_flops + bias_flops
  64. """
  65. bias = (
  66. input_shapes.get('Bias')[0]
  67. if len(input_shapes.get('Bias')) > 0
  68. else None
  69. )
  70. input = input_shapes.get('Input')[0]
  71. weight = input_shapes.get('Filter')[0]
  72. padding = attrs.get('paddings')
  73. stride = attrs.get('strides')
  74. dilation = attrs.get('dilations')
  75. groups = attrs.get('groups')
  76. batch_size = input[0]
  77. in_channels = input[1]
  78. out_channels = weight[0]
  79. kernel_dims = list(weight[2:])
  80. input_dims = list(input[2:])
  81. length = len(input_dims)
  82. paddings = (
  83. padding
  84. if isinstance(padding, list)
  85. else [
  86. padding,
  87. ]
  88. * length
  89. )
  90. strides = (
  91. stride
  92. if isinstance(stride, list)
  93. else [
  94. stride,
  95. ]
  96. * length
  97. )
  98. dilations = (
  99. dilation
  100. if isinstance(dilation, list)
  101. else [
  102. dilation,
  103. ]
  104. * length
  105. )
  106. output_dims = []
  107. for idx, input_dim in enumerate(input_dims):
  108. output_dim = (
  109. input_dim
  110. + 2 * paddings[idx]
  111. - (dilations[idx] * (kernel_dims[idx] - 1) + 1)
  112. ) // strides[idx] + 1
  113. output_dims.append(output_dim)
  114. filters_per_channel = out_channels // groups
  115. macs_conv_per_position = (
  116. prod(kernel_dims) * in_channels * filters_per_channel
  117. )
  118. active_elements = batch_size * prod(output_dims)
  119. overall_conv_macs = macs_conv_per_position * active_elements
  120. overall_conv_flops = 2 * overall_conv_macs
  121. overall_bias_flops = 0
  122. if bias is not None:
  123. overall_bias_flops = out_channels * active_elements
  124. return overall_conv_flops + overall_bias_flops
  125. @register_flops("dropout")
  126. def _dropout_flops(input_shapes, attrs):
  127. """FLOPs computation for dropout op.
  128. For dropout(input):
  129. equation: flops = 0
  130. """
  131. return 0
  132. def _elementwise_flops_compute(input_shapes, attrs):
  133. input_x = input_shapes.get("X")[0]
  134. input_y = input_shapes.get("Y")[0]
  135. dim_x = len(input_x)
  136. dim_y = len(input_y)
  137. dim_output = max(dim_x, dim_y)
  138. output = []
  139. for i in range(dim_output):
  140. in_x = input_x[dim_x - 1 - i] if i < dim_x else 1
  141. in_y = input_y[dim_y - 1 - i] if i < dim_y else 1
  142. output.append(max(in_x, in_y))
  143. return prod(output)
  144. @register_flops("elementwise_add")
  145. def _elementwise_add_flops(input_shapes, attrs):
  146. """FLOPs computation for elementwise_add op.
  147. For elementwise_add(input,other):
  148. input_shapes = [shape_of_input, shape_of_other]
  149. shape_of_input = [dim1, dim2, dim3 ...]
  150. shape_of_other = [odim1, odim2, odim3...]
  151. equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...
  152. """
  153. return _elementwise_flops_compute(input_shapes, attrs)
  154. @register_flops("elementwise_mul")
  155. def _elementwise_mul_flops(input_shapes, attrs):
  156. """FLOPs computation for elementwise_mul op.
  157. For elementwise_mul(input,other):
  158. input_shapes = [shape_of_input, shape_of_other]
  159. shape_of_input = [dim1, dim2, dim3 ...]
  160. shape_of_other = [odim1, odim2, odim3...]
  161. equation: flops = max(dim1, odim1) * max(dim2, odim2)* max()...
  162. """
  163. return _elementwise_flops_compute(input_shapes, attrs)
  164. @register_flops("elementwise_div")
  165. def _elementwise_div_flops(input_shapes, attrs):
  166. """FLOPs computation for elementwise_div op.
  167. For elementwise_div(input,other):
  168. input_shapes = [shape_of_input, shape_of_other]
  169. shape_of_input = [dim1, dim2, dim3 ...]
  170. shape_of_other = [odim1, odim2, odim3...]
  171. equation: flops = max(dim1,odim1)*max(dim2,odim2)*max()...
  172. """
  173. return _elementwise_flops_compute(input_shapes, attrs)
  174. @register_flops("gelu")
  175. def _gelu_flops(input_shapes, attrs):
  176. """FLOPs computation for gelu op.
  177. For gelu(input):
  178. equation: flops = 5 * (numel)total number of elements in the input tensor.
  179. """
  180. input = input_shapes.get('X')[0]
  181. return prod(input) * 5
  182. @register_flops("layer_norm")
  183. def _layer_norm_flops(input_shapes, attrs):
  184. """FLOPs computation for layer_norm op.
  185. For layer_norm(input):
  186. equation:
  187. 1): WITHOUT epsilon flops = 7 * (numel)total number of elements in the input tensor.
  188. 2): WITH epsilon flops = 8 * (numel)total number of elements in the input tensor.
  189. """
  190. input = input_shapes.get('X')[0]
  191. flops = prod(input) * 7
  192. if attrs.get('epsilon'):
  193. flops += prod(input)
  194. return flops
  195. @register_flops("matmul")
  196. def _matmul_flops(input_shapes, attrs):
  197. """FLOPs computation for matmul op.
  198. For matmul(input,other):
  199. input_shapes = [shape_of_input, shape_of_other]
  200. shape_of_input = [dim1,dim2 ...dim_n_1,dim_n] length:n
  201. shape_of_other = [odim1,odim2 ... odim(n-m)... odim_m_1,dim_m] length:m
  202. suppose n > m and dim_n = odim_m_1:
  203. shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, dim_m]
  204. equation: flops = 2 * numel(output) * dim_n
  205. """
  206. x_shape = copy.deepcopy(
  207. input_shapes.get("X", input_shapes.get("x", [[0]]))[0]
  208. )
  209. y_shape = copy.deepcopy(
  210. input_shapes.get("Y", input_shapes.get("y", [[0]]))[0]
  211. )
  212. if attrs.get('transpose_X') or attrs.get('transpose_x'):
  213. x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
  214. if attrs.get('transpose_Y') or attrs.get('transpose_y'):
  215. y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
  216. dim_x = len(x_shape)
  217. dim_y = len(y_shape)
  218. output_len = max(dim_x, dim_y)
  219. output_shape = []
  220. for idx in range(output_len, 2, -1):
  221. x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
  222. y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
  223. output_shape.append(max(x_idx, y_idx))
  224. macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
  225. return 2 * macs
  226. @register_flops("matmul_v2")
  227. def _matmul_v2_flops(input_shapes, attrs):
  228. """FLOPs computation for matmul_v2 op.
  229. For matmul_v2(input,other):
  230. input_shapes = [shape_of_input, shape_of_other]
  231. shape_of_input = [dim1, dim2 ...dim_n_1, dim_n] length:n
  232. shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m] length:m
  233. suppose n > m and dim_n = odim_m_1:
  234. shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m]
  235. equation: flops = 2 * numel(outputs) * dim_n
  236. """
  237. x_shape = copy.deepcopy(input_shapes.get('X')[0])
  238. y_shape = copy.deepcopy(input_shapes.get('Y')[0])
  239. if attrs.get('trans_x'):
  240. x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
  241. if attrs.get('trans_y'):
  242. y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
  243. dim_x = len(x_shape)
  244. dim_y = len(y_shape)
  245. output_len = max(dim_x, dim_y)
  246. output_shape = []
  247. for idx in range(output_len, 2, -1):
  248. x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
  249. y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
  250. output_shape.append(max(x_idx, y_idx))
  251. macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
  252. return 2 * macs
  253. def _relu_class_flops(input_shapes, attrs):
  254. """FLOPs computation for relu_like ops.
  255. For elu/leaky_relu/prelu/relu/relu6/silu (input):
  256. equation: flops = (numel)total number of elements in the input tensor.
  257. """
  258. input = input_shapes.get('X')[0]
  259. return prod(input)
  260. @register_flops("elu")
  261. def _elu_flops(input_shapes, attrs):
  262. return _relu_class_flops(input_shapes, attrs)
  263. @register_flops("leaky_relu")
  264. def _leaky_relu_flops(input_shapes, attrs):
  265. return _relu_class_flops(input_shapes, attrs)
  266. @register_flops("prelu")
  267. def _prelu_flops(input_shapes, attrs):
  268. return _relu_class_flops(input_shapes, attrs)
  269. @register_flops("relu")
  270. def _relu_flops(input_shapes, attrs):
  271. return _relu_class_flops(input_shapes, attrs)
  272. @register_flops("relu6")
  273. def _relu6_flops(input_shapes, attrs):
  274. return _relu_class_flops(input_shapes, attrs)
  275. @register_flops("silu")
  276. def _silu_flops(input_shapes, attrs):
  277. return _relu_class_flops(input_shapes, attrs)
  278. @register_flops("reshape2")
  279. def _reshape2_flops(input_shapes, attrs):
  280. """FLOPs computation for reshape2 op.
  281. For reshape2(input):
  282. equation: flops = 0
  283. """
  284. return 0
  285. @register_flops("softmax")
  286. def _softmax_flops(input_shapes, attrs):
  287. """FLOPs computation for softmax op.
  288. For softmax(input):
  289. equation: flops = 3 * (numel)total number of elements in the input tensor.
  290. """
  291. input = input_shapes.get('X')[0]
  292. return prod(input) * 3
  293. @register_flops("transpose2")
  294. def _transpose2_flops(input_shapes, attrs):
  295. """FLOPs computation for transpose2 op.
  296. For transpose2(input):
  297. equation: flops = 0
  298. """
  299. return 0
  300. @register_flops("pool")
  301. def _pool_flops(input_shapes, attrs):
  302. """FLOPs computation for pool op.
  303. For pool(input):
  304. equation: flops = (numel)total number of elements in the input tensor.
  305. """
  306. input = input_shapes.get('X')[0]
  307. return prod(input)