dynamic_flops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. import numpy as np
  16. import paddle
  17. from paddle import nn
  18. from paddle.jit.dy2static.program_translator import unwrap_decorators
  19. from .static_flops import Table, static_flops
# Empty export list: ``from <this module> import *`` exports nothing. The
# public entry point ``flops`` is presumably re-exported by the package
# (the docstring example calls it as ``paddle.flops``) — confirm in __init__.
__all__ = []
  21. def flops(net, input_size, custom_ops=None, print_detail=False):
  22. """Print a table about the FLOPs of network.
  23. Args:
  24. net (paddle.nn.Layer||paddle.static.Program): The network which could be a instance of paddle.nn.Layer in
  25. dygraph or paddle.static.Program in static graph.
  26. input_size (list): size of input tensor. Note that the batch_size in argument ``input_size`` only support 1.
  27. custom_ops (A dict of function, optional): A dictionary which key is the class of specific operation such as
  28. paddle.nn.Conv2D and the value is the function used to count the FLOPs of this operation. This
  29. argument only work when argument ``net`` is an instance of paddle.nn.Layer. The details could be found
  30. in following example code. Default is None.
  31. print_detail (bool, optional): Whether to print the detail information, like FLOPs per layer, about the net FLOPs.
  32. Default is False.
  33. Returns:
  34. Int: A number about the FLOPs of total network.
  35. Examples:
  36. .. code-block:: python
  37. >>> import paddle
  38. >>> import paddle.nn as nn
  39. >>> class LeNet(nn.Layer):
  40. ... def __init__(self, num_classes=10):
  41. ... super().__init__()
  42. ... self.num_classes = num_classes
  43. ... self.features = nn.Sequential(
  44. ... nn.Conv2D(1, 6, 3, stride=1, padding=1),
  45. ... nn.ReLU(),
  46. ... nn.MaxPool2D(2, 2),
  47. ... nn.Conv2D(6, 16, 5, stride=1, padding=0),
  48. ... nn.ReLU(),
  49. ... nn.MaxPool2D(2, 2))
  50. ...
  51. ... if num_classes > 0:
  52. ... self.fc = nn.Sequential(
  53. ... nn.Linear(400, 120),
  54. ... nn.Linear(120, 84),
  55. ... nn.Linear(84, 10))
  56. ...
  57. ... def forward(self, inputs):
  58. ... x = self.features(inputs)
  59. ...
  60. ... if self.num_classes > 0:
  61. ... x = paddle.flatten(x, 1)
  62. ... x = self.fc(x)
  63. ... return x
  64. ...
  65. >>> lenet = LeNet()
  66. >>> # m is the instance of nn.Layer, x is the input of layer, y is the output of layer.
  67. >>> def count_leaky_relu(m, x, y):
  68. ... x = x[0]
  69. ... nelements = x.numel()
  70. ... m.total_ops += int(nelements)
  71. ...
  72. >>> FLOPs = paddle.flops(lenet,
  73. ... [1, 1, 28, 28],
  74. ... custom_ops= {nn.LeakyReLU: count_leaky_relu},
  75. ... print_detail=True)
  76. >>> print(FLOPs)
  77. <class 'paddle.nn.layer.conv.Conv2D'>'s flops has been counted
  78. <class 'paddle.nn.layer.activation.ReLU'>'s flops has been counted
  79. Cannot find suitable count function for <class 'paddle.nn.layer.pooling.MaxPool2D'>. Treat it as zero FLOPs.
  80. <class 'paddle.nn.layer.common.Linear'>'s flops has been counted
  81. +--------------+-----------------+-----------------+--------+--------+
  82. | Layer Name | Input Shape | Output Shape | Params | Flops |
  83. +--------------+-----------------+-----------------+--------+--------+
  84. | conv2d_0 | [1, 1, 28, 28] | [1, 6, 28, 28] | 60 | 47040 |
  85. | re_lu_0 | [1, 6, 28, 28] | [1, 6, 28, 28] | 0 | 0 |
  86. | max_pool2d_0 | [1, 6, 28, 28] | [1, 6, 14, 14] | 0 | 0 |
  87. | conv2d_1 | [1, 6, 14, 14] | [1, 16, 10, 10] | 2416 | 241600 |
  88. | re_lu_1 | [1, 16, 10, 10] | [1, 16, 10, 10] | 0 | 0 |
  89. | max_pool2d_1 | [1, 16, 10, 10] | [1, 16, 5, 5] | 0 | 0 |
  90. | linear_0 | [1, 400] | [1, 120] | 48120 | 48000 |
  91. | linear_1 | [1, 120] | [1, 84] | 10164 | 10080 |
  92. | linear_2 | [1, 84] | [1, 10] | 850 | 840 |
  93. +--------------+-----------------+-----------------+--------+--------+
  94. Total Flops: 347560 Total Params: 61610
  95. 347560
  96. """
  97. if isinstance(net, nn.Layer):
  98. # If net is a dy2stat model, net.forward is StaticFunction instance,
  99. # we set net.forward to original forward function.
  100. _, net.forward = unwrap_decorators(net.forward)
  101. inputs = paddle.randn(input_size)
  102. return dynamic_flops(
  103. net, inputs=inputs, custom_ops=custom_ops, print_detail=print_detail
  104. )
  105. elif isinstance(net, paddle.static.Program):
  106. return static_flops(net, print_detail=print_detail)
  107. else:
  108. warnings.warn(
  109. "Your model must be an instance of paddle.nn.Layer or paddle.static.Program."
  110. )
  111. return -1
  112. def count_convNd(m, x, y):
  113. x = x[0]
  114. kernel_ops = np.prod(m.weight.shape[2:])
  115. bias_ops = 1 if m.bias is not None else 0
  116. total_ops = int(y.numel()) * (
  117. x.shape[1] / m._groups * kernel_ops + bias_ops
  118. )
  119. m.total_ops += abs(int(total_ops))
  120. def count_leaky_relu(m, x, y):
  121. x = x[0]
  122. nelements = x.numel()
  123. m.total_ops += int(nelements)
  124. def count_bn(m, x, y):
  125. x = x[0]
  126. nelements = x.numel()
  127. if not m.training:
  128. total_ops = 2 * nelements
  129. m.total_ops += abs(int(total_ops))
  130. def count_linear(m, x, y):
  131. total_mul = m.weight.shape[0]
  132. num_elements = y.numel()
  133. total_ops = total_mul * num_elements
  134. m.total_ops += abs(int(total_ops))
  135. def count_avgpool(m, x, y):
  136. kernel_ops = 1
  137. num_elements = y.numel()
  138. total_ops = kernel_ops * num_elements
  139. m.total_ops += int(total_ops)
  140. def count_adap_avgpool(m, x, y):
  141. kernel = np.array(x[0].shape[2:]) // np.array(y.shape[2:])
  142. total_add = np.prod(kernel)
  143. total_div = 1
  144. kernel_ops = total_add + total_div
  145. num_elements = y.numel()
  146. total_ops = kernel_ops * num_elements
  147. m.total_ops += abs(int(total_ops))
def count_zero_ops(m, x, y):
    """FLOPs counter for layers that are treated as free.

    It is mapped to ``nn.ReLU``, ``nn.ReLU6`` and ``nn.Dropout`` in
    ``register_hooks``. It adds zero to ``m.total_ops``. Having a registered
    counter means these layers are reported as counted rather than as lacking
    a count function.
    """
    m.total_ops += 0
  150. def count_parameters(m, x, y):
  151. total_params = 0
  152. for p in m.parameters():
  153. total_params += p.numel()
  154. m.total_params[0] = abs(int(total_params))
  155. def count_io_info(m, x, y):
  156. m.register_buffer('input_shape', paddle.to_tensor(x[0].shape))
  157. if isinstance(y, (list, tuple)):
  158. m.register_buffer('output_shape', paddle.to_tensor(y[0].shape))
  159. else:
  160. m.register_buffer('output_shape', paddle.to_tensor(y.shape))
  161. register_hooks = {
  162. nn.Conv1D: count_convNd,
  163. nn.Conv2D: count_convNd,
  164. nn.Conv3D: count_convNd,
  165. nn.Conv1DTranspose: count_convNd,
  166. nn.Conv2DTranspose: count_convNd,
  167. nn.Conv3DTranspose: count_convNd,
  168. nn.layer.norm.BatchNorm2D: count_bn,
  169. nn.BatchNorm: count_bn,
  170. nn.ReLU: count_zero_ops,
  171. nn.ReLU6: count_zero_ops,
  172. nn.LeakyReLU: count_leaky_relu,
  173. nn.Linear: count_linear,
  174. nn.Dropout: count_zero_ops,
  175. nn.AvgPool1D: count_avgpool,
  176. nn.AvgPool2D: count_avgpool,
  177. nn.AvgPool3D: count_avgpool,
  178. nn.AdaptiveAvgPool1D: count_adap_avgpool,
  179. nn.AdaptiveAvgPool2D: count_adap_avgpool,
  180. nn.AdaptiveAvgPool3D: count_adap_avgpool,
  181. }
  182. def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
  183. handler_collection = []
  184. types_collection = set()
  185. if custom_ops is None:
  186. custom_ops = {}
  187. def add_hooks(m):
  188. if len(list(m.children())) > 0:
  189. return
  190. m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
  191. m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
  192. m_type = type(m)
  193. flops_fn = None
  194. if m_type in custom_ops:
  195. flops_fn = custom_ops[m_type]
  196. if m_type not in types_collection:
  197. print(f"Customize Function has been applied to {m_type}")
  198. elif m_type in register_hooks:
  199. flops_fn = register_hooks[m_type]
  200. if m_type not in types_collection:
  201. print(f"{m_type}'s flops has been counted")
  202. else:
  203. if m_type not in types_collection:
  204. print(
  205. f"Cannot find suitable count function for {m_type}. Treat it as zero FLOPs."
  206. )
  207. if flops_fn is not None:
  208. flops_handler = m.register_forward_post_hook(flops_fn)
  209. handler_collection.append(flops_handler)
  210. params_handler = m.register_forward_post_hook(count_parameters)
  211. io_handler = m.register_forward_post_hook(count_io_info)
  212. handler_collection.append(params_handler)
  213. handler_collection.append(io_handler)
  214. types_collection.add(m_type)
  215. training = model.training
  216. model.eval()
  217. model.apply(add_hooks)
  218. with paddle.framework.no_grad():
  219. model(inputs)
  220. total_ops = 0
  221. total_params = 0
  222. for m in model.sublayers():
  223. if len(list(m.children())) > 0:
  224. continue
  225. if {
  226. 'total_ops',
  227. 'total_params',
  228. 'input_shape',
  229. 'output_shape',
  230. }.issubset(set(m._buffers.keys())):
  231. total_ops += m.total_ops
  232. total_params += m.total_params
  233. if training:
  234. model.train()
  235. for handler in handler_collection:
  236. handler.remove()
  237. table = Table(
  238. ["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"]
  239. )
  240. for n, m in model.named_sublayers():
  241. if len(list(m.children())) > 0:
  242. continue
  243. if {
  244. 'total_ops',
  245. 'total_params',
  246. 'input_shape',
  247. 'output_shape',
  248. }.issubset(set(m._buffers.keys())):
  249. table.add_row(
  250. [
  251. m.full_name(),
  252. list(m.input_shape.numpy()),
  253. list(m.output_shape.numpy()),
  254. int(m.total_params),
  255. int(m.total_ops),
  256. ]
  257. )
  258. m._buffers.pop("total_ops")
  259. m._buffers.pop("total_params")
  260. m._buffers.pop('input_shape')
  261. m._buffers.pop('output_shape')
  262. if print_detail:
  263. table.print_table()
  264. print(
  265. f'Total Flops: {int(total_ops)} Total Params: {int(total_params)}'
  266. )
  267. return int(total_ops)