| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
# Registry of FLOPs computation functions, keyed by operation type. Entries
# are added by the ``register_flops`` decorator and looked up by ``flops``,
# which calls them as ``func(input_shapes, attrs)``.
_FLOPS_COMPUTE_FUNC_MAP = {}
def prod(s):
    """Return the product of all values yielded by ``s``.

    Args:
        s: an iterable of values that support ``*``.

    Returns:
        The running product, starting from ``1``; an empty iterable
        therefore yields ``1``.
    """
    result = 1
    for factor in s:
        result *= factor
    return result
def flops(op_type: str, input_shapes: dict, attrs: dict) -> int:
    """
    Count the FLOPs of a single operation.

    Args:
        op_type (str): the type of operation.
        input_shapes (dict): the shapes of inputs.
        attrs (dict): the attributes of the operation.

    Returns:
        the total FLOPs of the operation. ``0`` is returned when no
        computation function is registered for ``op_type`` or when the
        registered function raises an exception.
    """
    if op_type not in _FLOPS_COMPUTE_FUNC_MAP:
        return 0

    compute = _FLOPS_COMPUTE_FUNC_MAP[op_type]
    try:
        return compute(input_shapes, attrs)
    except Exception:
        # Best effort: a failing computation function is reported as zero
        # FLOPs rather than propagated to the caller.
        return 0
def register_flops(op_type):
    """
    Build a decorator that registers a FLOPs computation function.

    Args:
        op_type: the operation type used as the key in
            ``_FLOPS_COMPUTE_FUNC_MAP``.

    Returns:
        A decorator that stores the decorated function in the map under
        ``op_type`` (replacing any previous entry) and returns the function
        unchanged.
    """

    def register(func):
        # Item assignment mutates the module-level map in place; the name
        # itself is never rebound here.
        _FLOPS_COMPUTE_FUNC_MAP[op_type] = func
        return func

    return register
@register_flops("c_embedding")
def _c_embedding_flops(input_shapes, attrs):
    """FLOPs computation for c_embedding op.

    For c_embedding(input):
        equation: flops = 0

    Both arguments are ignored; the op is counted as zero FLOPs.
    """
    return 0
@register_flops("conv2d")
def _conv2d_flops(input_shapes, attrs):
    """FLOPs computation for conv2d op.

    For conv2d(input, filter):
        active_elements = batch_size * numel(output spatial dims)
        conv_flops = 2 * macs_per_position_conv * active_elements
        bias_flops = out_channels * active_elements
        equation: flops = conv_flops + bias_flops

    Args:
        input_shapes (dict): 'Input' and 'Filter' each map to a list whose
            first entry is the tensor shape. 'Bias' is optional: a missing
            key, ``None`` or an empty list all mean "no bias".
        attrs (dict): 'paddings', 'strides' and 'dilations' (each either a
            single value applied to every spatial dim or a list/tuple with
            one value per spatial dim) and 'groups'.

    Returns:
        The total FLOPs of the convolution, including the bias addition when
        a bias shape is present.
    """
    # A missing 'Bias' entry used to make len() fail on None; the caller
    # swallowed that error and reported the whole convolution as 0 FLOPs.
    bias_shapes = input_shapes.get('Bias') or []
    bias = bias_shapes[0] if len(bias_shapes) > 0 else None

    input_shape = input_shapes.get('Input')[0]
    filter_shape = input_shapes.get('Filter')[0]
    groups = attrs.get('groups')

    # Indexing assumes a channels-first input: [batch, channels, *spatial].
    batch_size = input_shape[0]
    in_channels = input_shape[1]
    out_channels = filter_shape[0]
    kernel_dims = list(filter_shape[2:])
    input_dims = list(input_shape[2:])
    num_spatial_dims = len(input_dims)

    def per_spatial_dim(value):
        # Sequences are used as given; a scalar is repeated for every
        # spatial dimension.
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value] * num_spatial_dims

    paddings = per_spatial_dim(attrs.get('paddings'))
    strides = per_spatial_dim(attrs.get('strides'))
    dilations = per_spatial_dim(attrs.get('dilations'))

    # Standard convolution output-size formula with symmetric padding.
    output_dims = [
        (
            input_dim
            + 2 * paddings[idx]
            - (dilations[idx] * (kernel_dims[idx] - 1) + 1)
        )
        // strides[idx]
        + 1
        for idx, input_dim in enumerate(input_dims)
    ]

    filters_per_channel = out_channels // groups
    macs_conv_per_position = (
        prod(kernel_dims) * in_channels * filters_per_channel
    )
    active_elements = batch_size * prod(output_dims)
    # One multiply-accumulate is counted as two FLOPs.
    overall_conv_flops = 2 * macs_conv_per_position * active_elements

    overall_bias_flops = 0
    if bias is not None:
        overall_bias_flops = out_channels * active_elements

    return overall_conv_flops + overall_bias_flops
@register_flops("dropout")
def _dropout_flops(input_shapes, attrs):
    """FLOPs computation for dropout op.

    For dropout(input):
        equation: flops = 0

    Both arguments are ignored; the op is counted as zero FLOPs.
    """
    return 0
def _elementwise_flops_compute(input_shapes, attrs):
    """Shared FLOPs computation for binary elementwise ops.

    The shapes of 'X' and 'Y' are aligned on their trailing dimensions and
    the shorter one is padded with 1. Each output dimension is the larger of
    the two aligned sizes, and the FLOPs are the product of those sizes.
    ``attrs`` is not used.
    """
    shape_x = input_shapes.get("X")[0]
    shape_y = input_shapes.get("Y")[0]
    rank = max(len(shape_x), len(shape_y))

    # Reverse both shapes so index 0 is the last dimension, then pad the
    # shorter one with 1s for the leading dimensions it does not have.
    reversed_x = list(shape_x[::-1]) + [1] * (rank - len(shape_x))
    reversed_y = list(shape_y[::-1]) + [1] * (rank - len(shape_y))

    broadcast_dims = [
        max(size_x, size_y) for size_x, size_y in zip(reversed_x, reversed_y)
    ]
    return prod(broadcast_dims)
@register_flops("elementwise_add")
def _elementwise_add_flops(input_shapes, attrs):
    """FLOPs computation for elementwise_add op.

    For elementwise_add(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2, dim3 ...]
        shape_of_other = [odim1, odim2, odim3 ...]
        equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...

    The computation is delegated to ``_elementwise_flops_compute``.
    """
    return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("elementwise_mul")
def _elementwise_mul_flops(input_shapes, attrs):
    """FLOPs computation for elementwise_mul op.

    For elementwise_mul(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2, dim3 ...]
        shape_of_other = [odim1, odim2, odim3 ...]
        equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...

    The computation is delegated to ``_elementwise_flops_compute``.
    """
    return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("elementwise_div")
def _elementwise_div_flops(input_shapes, attrs):
    """FLOPs computation for elementwise_div op.

    For elementwise_div(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2, dim3 ...]
        shape_of_other = [odim1, odim2, odim3 ...]
        equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...

    The computation is delegated to ``_elementwise_flops_compute``.
    """
    return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("gelu")
def _gelu_flops(input_shapes, attrs):
    """FLOPs computation for gelu op.

    For gelu(input):
        equation: flops = 5 * numel, where numel is the total number of
        elements in the input tensor 'X'.

    ``attrs`` is not used.
    """
    x_shape = input_shapes.get('X')[0]
    numel = prod(x_shape)
    return numel * 5
@register_flops("layer_norm")
def _layer_norm_flops(input_shapes, attrs):
    """FLOPs computation for layer_norm op.

    For layer_norm(input):
        equation:
        1): WITHOUT epsilon flops = 7 * numel
        2): WITH epsilon flops = 8 * numel
        where numel is the total number of elements in the input tensor 'X'.

    "With epsilon" means ``attrs.get('epsilon')`` is truthy, so a missing
    or zero epsilon falls under case 1.
    """
    x_shape = input_shapes.get('X')[0]
    numel = prod(x_shape)
    # One extra FLOP per element when a truthy epsilon is supplied.
    epsilon_flops = numel if attrs.get('epsilon') else 0
    return numel * 7 + epsilon_flops
@register_flops("matmul")
def _matmul_flops(input_shapes, attrs):
    """FLOPs computation for matmul op.

    For matmul(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2 ... dim_n_1, dim_n]  length: n
        shape_of_other = [odim1, odim2 ... odim_m_1, dim_m]  length: m
        suppose n > m and dim_n = odim_m_1:
        shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)),
                           max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, dim_m]
        equation: flops = 2 * numel(output) * dim_n

    Args:
        input_shapes (dict): 'X' (or 'x') and 'Y' (or 'y') each map to a
            list whose first entry is the operand shape. A missing operand
            falls back to the shape ``[0]``, which yields 0 FLOPs.
        attrs (dict): 'transpose_X'/'transpose_x' and
            'transpose_Y'/'transpose_y' swap the last two dims of the
            corresponding operand when truthy.

    Returns:
        Twice the number of multiply-accumulates of the (batched) product.

    Raises:
        IndexError: if an operand shape is empty (0-D).
    """
    # list() gives private, mutable copies, so the caller's shapes are never
    # modified and tuple shapes can be transposed as well.
    x_shape = list(input_shapes.get("X", input_shapes.get("x", [[0]]))[0])
    y_shape = list(input_shapes.get("Y", input_shapes.get("y", [[0]]))[0])

    # Per the documented matmul convention, a 1-D X is treated as [1, D] and
    # a 1-D Y as [D, 1] before any transposition. Without this promotion the
    # [-2] lookups below raised IndexError for vectors (and for the [0]
    # fallback).
    if len(x_shape) == 1:
        x_shape = [1] + x_shape
    if len(y_shape) == 1:
        y_shape = y_shape + [1]

    if attrs.get('transpose_X') or attrs.get('transpose_x'):
        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
    if attrs.get('transpose_Y') or attrs.get('transpose_y'):
        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]

    dim_x = len(x_shape)
    dim_y = len(y_shape)

    # Broadcast the leading (batch) dims: align from the right, pad the
    # shorter operand with 1 and keep the larger size for each position.
    batch_dims = []
    for idx in range(max(dim_x, dim_y), 2, -1):
        x_size = x_shape[dim_x - idx] if idx <= dim_x else 1
        y_size = y_shape[dim_y - idx] if idx <= dim_y else 1
        batch_dims.append(max(x_size, y_size))

    # [.., M, K] x [.., K, N] needs M * K * N multiply-accumulates per batch.
    macs = prod(batch_dims) * x_shape[-2] * x_shape[-1] * y_shape[-1]
    return 2 * macs
@register_flops("matmul_v2")
def _matmul_v2_flops(input_shapes, attrs):
    """FLOPs computation for matmul_v2 op.

    For matmul_v2(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2 ... dim_n_1, dim_n]  length: n
        shape_of_other = [odim1, odim2 ... odim_m_1, dim_m]  length: m
        suppose n > m and dim_n = odim_m_1:
        shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)),
                           max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, dim_m]
        equation: flops = 2 * numel(outputs) * dim_n

    Args:
        input_shapes (dict): 'X' and 'Y' each map to a list whose first
            entry is the operand shape.
        attrs (dict): 'trans_x' / 'trans_y' swap the last two dims of the
            corresponding operand when truthy; they are ignored for 1-D
            operands.

    Returns:
        Twice the number of multiply-accumulates of the (batched) product.

    Raises:
        IndexError: if an operand shape is empty (0-D).
    """
    # list() gives private, mutable copies, so the caller's shapes are never
    # modified and tuple shapes can be transposed as well.
    x_shape = list(input_shapes.get('X')[0])
    y_shape = list(input_shapes.get('Y')[0])

    # Per the documented paddle.matmul convention, a 1-D x is treated as
    # [1, D], a 1-D y as [D, 1], and the transpose flag does not apply to a
    # 1-D operand. Without this promotion the [-2] lookups below raised
    # IndexError for vectors.
    if len(x_shape) == 1:
        x_shape = [1] + x_shape
    elif attrs.get('trans_x'):
        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]

    if len(y_shape) == 1:
        y_shape = y_shape + [1]
    elif attrs.get('trans_y'):
        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]

    dim_x = len(x_shape)
    dim_y = len(y_shape)

    # Broadcast the leading (batch) dims: align from the right, pad the
    # shorter operand with 1 and keep the larger size for each position.
    batch_dims = []
    for idx in range(max(dim_x, dim_y), 2, -1):
        x_size = x_shape[dim_x - idx] if idx <= dim_x else 1
        y_size = y_shape[dim_y - idx] if idx <= dim_y else 1
        batch_dims.append(max(x_size, y_size))

    # [.., M, K] x [.., K, N] needs M * K * N multiply-accumulates per batch.
    macs = prod(batch_dims) * x_shape[-2] * x_shape[-1] * y_shape[-1]
    return 2 * macs
def _relu_class_flops(input_shapes, attrs):
    """FLOPs computation shared by the relu-like ops.

    For elu/leaky_relu/prelu/relu/relu6/silu (input):
        equation: flops = numel, the total number of elements in the input
        tensor 'X'.

    ``attrs`` is not used.
    """
    x_shape = input_shapes.get('X')[0]
    numel = prod(x_shape)
    return numel
@register_flops("elu")
def _elu_flops(input_shapes, attrs):
    """FLOPs computation for elu op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("leaky_relu")
def _leaky_relu_flops(input_shapes, attrs):
    """FLOPs computation for leaky_relu op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("prelu")
def _prelu_flops(input_shapes, attrs):
    """FLOPs computation for prelu op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("relu")
def _relu_flops(input_shapes, attrs):
    """FLOPs computation for relu op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("relu6")
def _relu6_flops(input_shapes, attrs):
    """FLOPs computation for relu6 op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("silu")
def _silu_flops(input_shapes, attrs):
    """FLOPs computation for silu op; delegates to ``_relu_class_flops``."""
    return _relu_class_flops(input_shapes, attrs)
@register_flops("reshape2")
def _reshape2_flops(input_shapes, attrs):
    """FLOPs computation for reshape2 op.

    For reshape2(input):
        equation: flops = 0

    Both arguments are ignored; the op is counted as zero FLOPs.
    """
    return 0
@register_flops("softmax")
def _softmax_flops(input_shapes, attrs):
    """FLOPs computation for softmax op.

    For softmax(input):
        equation: flops = 3 * numel, where numel is the total number of
        elements in the input tensor 'X'.

    ``attrs`` is not used.
    """
    x_shape = input_shapes.get('X')[0]
    numel = prod(x_shape)
    return numel * 3
@register_flops("transpose2")
def _transpose2_flops(input_shapes, attrs):
    """FLOPs computation for transpose2 op.

    For transpose2(input):
        equation: flops = 0

    Both arguments are ignored; the op is counted as zero FLOPs.
    """
    return 0
@register_flops("pool")
def _pool_flops(input_shapes, attrs):
    """FLOPs computation for pool op.

    For pool(input):
        equation: flops = numel, the total number of elements in the input
        tensor 'X'.

    ``attrs`` is not used.
    """
    x_shape = input_shapes.get('X')[0]
    numel = prod(x_shape)
    return numel
|