static_flops.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections import OrderedDict
  15. import numpy as np
  16. from paddle.static import Program, Variable
  17. __all__ = []
  18. class VarWrapper:
  19. def __init__(self, var, graph):
  20. assert isinstance(var, Variable)
  21. assert isinstance(graph, GraphWrapper)
  22. self._var = var
  23. self._graph = graph
  24. def name(self):
  25. """
  26. Get the name of the variable.
  27. """
  28. return self._var.name
  29. def shape(self):
  30. """
  31. Get the shape of the variable.
  32. """
  33. return self._var.shape
  34. class OpWrapper:
  35. def __init__(self, op, graph):
  36. assert isinstance(graph, GraphWrapper)
  37. self._op = op
  38. self._graph = graph
  39. def type(self):
  40. """
  41. Get the type of this operator.
  42. """
  43. return self._op.type
  44. def inputs(self, name):
  45. """
  46. Get all the variables by the input name.
  47. """
  48. if name in self._op.input_names:
  49. return [
  50. self._graph.var(var_name) for var_name in self._op.input(name)
  51. ]
  52. else:
  53. return []
  54. def outputs(self, name):
  55. """
  56. Get all the variables by the output name.
  57. """
  58. return [self._graph.var(var_name) for var_name in self._op.output(name)]
  59. class GraphWrapper:
  60. """
  61. It is a wrapper of paddle.base.framework.IrGraph with some special functions
  62. for paddle slim framework.
  63. Args:
  64. program(framework.Program): A program with
  65. in_nodes(dict): A dict to indicate the input nodes of the graph.
  66. The key is user-defined and human-readable name.
  67. The value is the name of Variable.
  68. out_nodes(dict): A dict to indicate the input nodes of the graph.
  69. The key is user-defined and human-readable name.
  70. The value is the name of Variable.
  71. """
  72. def __init__(self, program=None, in_nodes=[], out_nodes=[]):
  73. """ """
  74. super().__init__()
  75. self.program = Program() if program is None else program
  76. self.persistables = {}
  77. self.teacher_persistables = {}
  78. for var in self.program.list_vars():
  79. if var.persistable:
  80. self.persistables[var.name] = var
  81. self.compiled_graph = None
  82. in_nodes = [] if in_nodes is None else in_nodes
  83. out_nodes = [] if out_nodes is None else out_nodes
  84. self.in_nodes = OrderedDict(in_nodes)
  85. self.out_nodes = OrderedDict(out_nodes)
  86. self._attrs = OrderedDict()
  87. def ops(self):
  88. """
  89. Return all operator nodes included in the graph as a set.
  90. """
  91. ops = []
  92. for block in self.program.blocks:
  93. for op in block.ops:
  94. ops.append(OpWrapper(op, self))
  95. return ops
  96. def var(self, name):
  97. """
  98. Get the variable by variable name.
  99. """
  100. for block in self.program.blocks:
  101. if block.has_var(name):
  102. return VarWrapper(block.var(name), self)
  103. return None
  104. def count_convNd(op):
  105. filter_shape = op.inputs("Filter")[0].shape()
  106. filter_ops = np.prod(filter_shape[1:])
  107. bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0
  108. output_numel = np.prod(op.outputs("Output")[0].shape()[1:])
  109. total_ops = output_numel * (filter_ops + bias_ops)
  110. total_ops = abs(total_ops)
  111. return total_ops
  112. def count_leaky_relu(op):
  113. total_ops = np.prod(op.outputs("Output")[0].shape()[1:])
  114. return total_ops
  115. def count_bn(op):
  116. output_numel = np.prod(op.outputs("Y")[0].shape()[1:])
  117. total_ops = 2 * output_numel
  118. total_ops = abs(total_ops)
  119. return total_ops
  120. def count_linear(op):
  121. total_mul = op.inputs("Y")[0].shape()[0]
  122. numel = np.prod(op.outputs("Out")[0].shape()[1:])
  123. total_ops = total_mul * numel
  124. total_ops = abs(total_ops)
  125. return total_ops
  126. def count_pool2d(op):
  127. input_shape = op.inputs("X")[0].shape()
  128. output_shape = op.outputs('Out')[0].shape()
  129. kernel = np.array(input_shape[2:]) // np.array(output_shape[2:])
  130. total_add = np.prod(kernel)
  131. total_div = 1
  132. kernel_ops = total_add + total_div
  133. num_elements = np.prod(output_shape[1:])
  134. total_ops = kernel_ops * num_elements
  135. total_ops = abs(total_ops)
  136. return total_ops
  137. def count_element_op(op):
  138. input_shape = op.inputs("X")[0].shape()
  139. total_ops = np.prod(input_shape[1:])
  140. total_ops = abs(total_ops)
  141. return total_ops
  142. def _graph_flops(graph, detail=False):
  143. assert isinstance(graph, GraphWrapper)
  144. flops = 0
  145. op_flops = 0
  146. table = Table(["OP Type", 'Param name', "Flops"])
  147. for op in graph.ops():
  148. param_name = ''
  149. if op.type() in ['conv2d', 'depthwise_conv2d']:
  150. op_flops = count_convNd(op)
  151. flops += op_flops
  152. param_name = op.inputs("Filter")[0].name()
  153. elif op.type() == 'pool2d':
  154. op_flops = count_pool2d(op)
  155. flops += op_flops
  156. elif op.type() in ['mul', 'matmul']:
  157. op_flops = count_linear(op)
  158. flops += op_flops
  159. param_name = op.inputs("Y")[0].name()
  160. elif op.type() == 'batch_norm':
  161. op_flops = count_bn(op)
  162. flops += op_flops
  163. elif op.type().startswith('element'):
  164. op_flops = count_element_op(op)
  165. flops += op_flops
  166. if op_flops != 0:
  167. table.add_row([op.type(), param_name, op_flops])
  168. op_flops = 0
  169. if detail:
  170. table.print_table()
  171. return flops
  172. def static_flops(program, print_detail=False):
  173. graph = GraphWrapper(program)
  174. return _graph_flops(graph, detail=print_detail)
  175. class Table:
  176. def __init__(self, table_heads):
  177. self.table_heads = table_heads
  178. self.table_len = []
  179. self.data = []
  180. self.col_num = len(table_heads)
  181. for head in table_heads:
  182. self.table_len.append(len(head))
  183. def add_row(self, row_str):
  184. if not isinstance(row_str, list):
  185. print('The row_str should be a list')
  186. if len(row_str) != self.col_num:
  187. print(
  188. f'The length of row data should be equal the length of table heads, but the data: {len(row_str)} is not equal table heads {self.col_num}'
  189. )
  190. for i in range(self.col_num):
  191. if len(str(row_str[i])) > self.table_len[i]:
  192. self.table_len[i] = len(str(row_str[i]))
  193. self.data.append(row_str)
  194. def print_row(self, row):
  195. string = ''
  196. for i in range(self.col_num):
  197. string += '|' + str(row[i]).center(self.table_len[i] + 2)
  198. string += '|'
  199. print(string)
  200. def print_shelf(self):
  201. string = ''
  202. for length in self.table_len:
  203. string += '+'
  204. string += '-' * (length + 2)
  205. string += '+'
  206. print(string)
  207. def print_table(self):
  208. self.print_shelf()
  209. self.print_row(self.table_heads)
  210. self.print_shelf()
  211. for data in self.data:
  212. self.print_row(data)
  213. self.print_shelf()