| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections import OrderedDict
- import numpy as np
- from paddle.static import Program, Variable
- __all__ = []
class VarWrapper:
    """
    Read-only handle on a single graph variable.

    Args:
        var(Variable): The variable being wrapped.
        graph(GraphWrapper): The graph wrapper this variable was looked up from.
    """

    def __init__(self, var, graph):
        # Both arguments are type-checked up front; a wrong type aborts
        # construction with an AssertionError.
        assert isinstance(var, Variable)
        assert isinstance(graph, GraphWrapper)
        self._graph = graph
        self._var = var

    def name(self):
        """
        Return the ``name`` attribute of the wrapped variable.
        """
        wrapped = self._var
        return wrapped.name

    def shape(self):
        """
        Return the ``shape`` attribute of the wrapped variable.
        """
        wrapped = self._var
        return wrapped.shape
class OpWrapper:
    """
    Handle on a single operator that resolves its argument names to
    variables through the owning graph wrapper.

    Args:
        op: The operator being wrapped.
        graph(GraphWrapper): The graph wrapper used to look variables up.
    """

    def __init__(self, op, graph):
        assert isinstance(graph, GraphWrapper)
        self._graph = graph
        self._op = op

    def type(self):
        """
        Return the ``type`` attribute of the wrapped operator.
        """
        return self._op.type

    def inputs(self, name):
        """
        Resolve the variables bound to the input slot ``name``.

        Returns an empty list when the operator has no input slot called
        ``name``.
        """
        if name not in self._op.input_names:
            return []
        resolved = []
        for arg_name in self._op.input(name):
            resolved.append(self._graph.var(arg_name))
        return resolved

    def outputs(self, name):
        """
        Resolve the variables bound to the output slot ``name``.

        Unlike ``inputs`` there is no membership check here: the slot name is
        handed straight to the operator's ``output`` lookup.
        """
        resolved = []
        for arg_name in self._op.output(name):
            resolved.append(self._graph.var(arg_name))
        return resolved
class GraphWrapper:
    """
    Wrap a static ``Program`` and expose its operators and variables as
    ``OpWrapper`` / ``VarWrapper`` objects.

    Args:
        program(framework.Program): The program to wrap. When ``None``, a
            fresh ``Program()`` is created.
        in_nodes(dict): Describes the input nodes of the graph.
            The key is user-defined and human-readable name.
            The value is the name of Variable.
            ``None`` (the default) means no input nodes.
        out_nodes(dict): Describes the output nodes of the graph.
            The key is user-defined and human-readable name.
            The value is the name of Variable.
            ``None`` (the default) means no output nodes.
    """

    def __init__(self, program=None, in_nodes=None, out_nodes=None):
        """Index the program's persistable variables and record the in/out nodes."""
        super().__init__()
        self.program = Program() if program is None else program
        # name -> variable, for every variable flagged as persistable.
        self.persistables = {
            var.name: var
            for var in self.program.list_vars()
            if var.persistable
        }
        # Created empty; nothing in this class populates it.
        self.teacher_persistables = {}
        self.compiled_graph = None
        # ``None`` defaults avoid sharing one mutable default object between
        # calls; the values are copied into fresh OrderedDicts either way.
        self.in_nodes = OrderedDict([] if in_nodes is None else in_nodes)
        self.out_nodes = OrderedDict([] if out_nodes is None else out_nodes)
        self._attrs = OrderedDict()

    def ops(self):
        """
        Return every operator of every block, wrapped in ``OpWrapper``, as a
        list in block order.
        """
        return [
            OpWrapper(op, self)
            for block in self.program.blocks
            for op in block.ops
        ]

    def var(self, name):
        """
        Get the variable by variable name.

        Blocks are searched in order and the first match is returned wrapped
        in ``VarWrapper``; ``None`` is returned when no block has the name.
        """
        for block in self.program.blocks:
            if block.has_var(name):
                return VarWrapper(block.var(name), self)
        return None
def count_convNd(op):
    """
    Estimate the operation count of a convolution operator.

    Each output element (leading dimension of the "Output" shape excluded) is
    charged the product of the "Filter" shape without its first dimension,
    plus one extra operation when a "Bias" input is present. The absolute
    value is returned so negative shape entries cannot produce a negative
    count.
    """
    kernel_shape = op.inputs("Filter")[0].shape()
    per_output = np.prod(kernel_shape[1:])
    if len(op.inputs("Bias")) > 0:
        per_output = per_output + 1
    out_shape = op.outputs("Output")[0].shape()
    return abs(np.prod(out_shape[1:]) * per_output)
def count_leaky_relu(op):
    """
    Estimate the operation count of a leaky-relu operator.

    One operation is charged per output element, with the leading dimension
    of the output shape excluded. The absolute value is returned, matching
    the other ``count_*`` helpers, so a negative shape entry cannot turn the
    count negative.
    """
    # NOTE(review): this reads the "Output" slot -- confirm that is the slot
    # name the operator actually uses.
    out_shape = op.outputs("Output")[0].shape()
    return abs(np.prod(out_shape[1:]))
def count_bn(op):
    """
    Estimate the operation count of a batch-norm operator.

    Two operations are charged per element of the "Y" output, with the
    leading dimension of its shape excluded. The absolute value is returned.
    """
    y_shape = op.outputs("Y")[0].shape()
    return abs(2 * np.prod(y_shape[1:]))
def count_linear(op):
    """
    Estimate the operation count of a mul/matmul operator.

    The first dimension of the "Y" input is multiplied by the number of
    elements of the "Out" output (leading dimension excluded). The absolute
    value is returned.
    """
    reduce_dim = op.inputs("Y")[0].shape()[0]
    out_shape = op.outputs("Out")[0].shape()
    return abs(reduce_dim * np.prod(out_shape[1:]))
def count_pool2d(op):
    """
    Estimate the operation count of a pool2d operator.

    The pooling window is inferred as the element-wise floor division of the
    "X" input shape by the "Out" output shape, taken from the third dimension
    on. Each output element (leading dimension excluded) is charged one
    addition per window element plus a single division. The absolute value is
    returned.
    """
    in_shape = op.inputs("X")[0].shape()
    out_shape = op.outputs('Out')[0].shape()
    window = np.floor_divide(np.array(in_shape[2:]), np.array(out_shape[2:]))
    # prod(window) additions + 1 division per output element.
    per_output = np.prod(window) + 1
    return abs(per_output * np.prod(out_shape[1:]))
def count_element_op(op):
    """
    Estimate the operation count of an element-wise operator.

    One operation is charged per element of the "X" input, with the leading
    dimension of its shape excluded. The absolute value is returned.
    """
    x_shape = op.inputs("X")[0].shape()
    return abs(np.prod(x_shape[1:]))
def _graph_flops(graph, detail=False):
    """
    Sum the estimated operation counts of every operator in ``graph``.

    Counted operator types: ``conv2d`` / ``depthwise_conv2d``, ``pool2d``,
    ``mul`` / ``matmul``, ``batch_norm`` and any type whose name starts with
    ``element``. All other operators contribute nothing.

    Args:
        graph(GraphWrapper): The graph whose operators are counted.
        detail(bool): When true, print a per-operator table. Only operators
            with a non-zero count get a row.

    Returns:
        The accumulated operation count.
    """
    assert isinstance(graph, GraphWrapper)
    report = Table(["OP Type", 'Param name', "Flops"])
    total = 0
    for node in graph.ops():
        kind = node.type()
        cost = 0
        weight_name = ''
        if kind in ('conv2d', 'depthwise_conv2d'):
            cost = count_convNd(node)
            weight_name = node.inputs("Filter")[0].name()
        elif kind == 'pool2d':
            cost = count_pool2d(node)
        elif kind in ('mul', 'matmul'):
            cost = count_linear(node)
            weight_name = node.inputs("Y")[0].name()
        elif kind == 'batch_norm':
            cost = count_bn(node)
        elif kind.startswith('element'):
            cost = count_element_op(node)
        total += cost
        if cost != 0:
            report.add_row([kind, weight_name, cost])
    if detail:
        report.print_table()
    return total
def static_flops(program, print_detail=False):
    """
    Estimate the operation count of a static program.

    Args:
        program: The program to analyse; it is wrapped in a ``GraphWrapper``.
        print_detail(bool): When true, a per-operator table is printed.

    Returns:
        The accumulated operation count computed by ``_graph_flops``.
    """
    return _graph_flops(GraphWrapper(program), detail=print_detail)
class Table:
    """
    Minimal text table printed to stdout with ``+---+`` borders.

    Column widths start at the header widths and grow to fit the widest cell
    added so far.

    Args:
        table_heads: The column header strings.
    """

    def __init__(self, table_heads):
        self.table_heads = table_heads
        self.col_num = len(table_heads)
        # Width of each column, seeded with the header widths.
        self.table_len = [len(head) for head in table_heads]
        self.data = []

    def add_row(self, row_str):
        """
        Append a row and widen the columns to fit it.

        A non-list row or a row of the wrong length only triggers a printed
        message; processing continues afterwards.
        """
        if not isinstance(row_str, list):
            print('The row_str should be a list')
        if len(row_str) != self.col_num:
            print(
                f'The length of row data should be equal the length of table heads, but the data: {len(row_str)} is not equal table heads {self.col_num}'
            )
        for col in range(self.col_num):
            cell_width = len(str(row_str[col]))
            if cell_width > self.table_len[col]:
                self.table_len[col] = cell_width
        self.data.append(row_str)

    def print_row(self, row):
        """
        Print one row, each cell centred in its column with one space of
        padding on either side.
        """
        cells = [
            '|' + str(row[col]).center(self.table_len[col] + 2)
            for col in range(self.col_num)
        ]
        print(''.join(cells) + '|')

    def print_shelf(self):
        """
        Print a horizontal border line matching the current column widths.
        """
        segments = ['+' + '-' * (width + 2) for width in self.table_len]
        print(''.join(segments) + '+')

    def print_table(self):
        """
        Print the header, every data row, and the surrounding border lines.
        """
        self.print_shelf()
        self.print_row(self.table_heads)
        self.print_shelf()
        for record in self.data:
            self.print_row(record)
        self.print_shelf()
|