utils.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import numpy as np
  16. from ...base.framework import IrNode, Operator
  17. from .quant_config import SUPPORT_QUANTIZATION_OP_DICT
  18. _channelwise_quant_axis1_ops = [
  19. 'conv2d_transpose',
  20. 'mul',
  21. 'matmul',
  22. 'matmul_v2',
  23. ]
  24. def _get_op_input_var_names(op):
  25. """
  26. Get the input var names of the op.
  27. Args:
  28. op(IrNode, Operator): the input op.
  29. Returns:
  30. input_var_names or None.
  31. """
  32. assert isinstance(
  33. op, (IrNode, Operator)
  34. ), "The input op should be IrNode or Operator."
  35. var_names = []
  36. op_name = op.name() if isinstance(op, IrNode) else op.type
  37. if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
  38. return []
  39. name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][0]
  40. for name in name_list:
  41. var_name = op.input(name)
  42. if isinstance(var_name, list):
  43. var_names.extend(var_name)
  44. else:
  45. var_names.append(var_name)
  46. return var_names
  47. def _get_op_output_var_names(op):
  48. """ """
  49. assert isinstance(
  50. op, (IrNode, Operator)
  51. ), "The input op should be IrNode or Operator."
  52. var_names = []
  53. op_name = op.name() if isinstance(op, IrNode) else op.type
  54. if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
  55. return []
  56. name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
  57. for name in name_list:
  58. var_name = op.output(name)
  59. if isinstance(var_name, list):
  60. var_names.extend(var_name)
  61. else:
  62. var_names.append(var_name)
  63. return var_names
  64. def _get_input_name_index(op, input_var_name):
  65. """Get the input name and index of the var_name in the op"""
  66. assert isinstance(
  67. op, (IrNode, Operator)
  68. ), "The input op should be IrNode or Operator."
  69. op_name = op.name() if isinstance(op, IrNode) else op.type
  70. if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
  71. return None
  72. res = None
  73. for argname in SUPPORT_QUANTIZATION_OP_DICT[op_name][0]:
  74. var_names = op.input(argname)
  75. for index, name in enumerate(var_names):
  76. if name == input_var_name:
  77. res = (argname, index)
  78. return res
  79. def _get_output_name_index(op, output_var_name):
  80. """Get the output name and index of the var_name in the op"""
  81. assert isinstance(
  82. op, (IrNode, Operator)
  83. ), "The input op should be IrNode or Operator."
  84. op_name = op.name() if isinstance(op, IrNode) else op.type
  85. if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
  86. return None
  87. name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
  88. res = None
  89. for name in name_list:
  90. var_name = op.output(name)
  91. for index, val in enumerate(var_name):
  92. if val == output_var_name:
  93. res = (name, index)
  94. return res
  95. def load_variable_data(scope, var_name):
  96. '''
  97. Load variable value from scope
  98. '''
  99. var_node = scope.find_var(var_name)
  100. assert var_node is not None, "Cannot find " + var_name + " in scope."
  101. return np.array(var_node.get_tensor())
  102. def set_variable_data(scope, place, var_name, np_value):
  103. '''
  104. Set the value of var node by name, if the node exits,
  105. '''
  106. assert isinstance(
  107. np_value, np.ndarray
  108. ), 'The type of value should be numpy array.'
  109. var_node = scope.find_var(var_name)
  110. if var_node is not None:
  111. tensor = var_node.get_tensor()
  112. tensor.set(np_value, place)
  113. def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
  114. # symmetry quant
  115. def _clip(x, scale):
  116. x[x > scale] = scale
  117. x[x < -scale] = -scale
  118. return x
  119. bnt = (1 << (weight_bits - 1)) - 1
  120. if isinstance(scale, list) and len(scale) == 1:
  121. scale = scale[0]
  122. if isinstance(scale, list):
  123. assert quant_axis in [-1, 0, 1], 'quant_axis should be 0 or 1 for now.'
  124. for i, s in enumerate(scale):
  125. if s == 0.0:
  126. s = 1e-8
  127. if quant_axis == 0:
  128. if onnx_format:
  129. x[i] = np.round(x[i] / s * bnt)
  130. x[i] = np.clip(x[i], -bnt - 1, bnt)
  131. else:
  132. x[i] = _clip(x[i], s)
  133. x[i] = x[i] / s * bnt
  134. else:
  135. if onnx_format:
  136. x[:, i] = np.round(x[:, i] / s * bnt)
  137. x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
  138. else:
  139. x[:, i] = _clip(x[:, i], s)
  140. x[:, i] = x[:, i] / s * bnt
  141. else:
  142. scale = 1e-8 if scale == 0.0 else scale
  143. if onnx_format:
  144. x = np.round(x / scale * bnt)
  145. x = np.clip(x, -bnt - 1, bnt)
  146. else:
  147. x = _clip(x, scale)
  148. x = x / scale * bnt
  149. return x
  150. def dequant_tensor(x, scale, quant_axis=0, weight_bits=8):
  151. assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
  152. bnt = (1 << (weight_bits - 1)) - 1
  153. if isinstance(scale, list):
  154. for i, s in enumerate(scale):
  155. if s == 0.0:
  156. s = 1e-8
  157. if quant_axis == 0:
  158. x[i] = x[i] * s / bnt
  159. else:
  160. x[:, i] = x[:, i] * s / bnt
  161. else:
  162. scale = 1e-8 if scale == 0.0 else scale
  163. x = x * scale / bnt
  164. return x
  165. def bias_correction_w(x, x_quant, scale_v, quant_axis, weight_bits=8):
  166. '''
  167. Bias correction for weight
  168. '''
  169. eps = 1e-8
  170. bnt = (1 << (weight_bits - 1)) - 1
  171. x_dequant = x_quant.copy()
  172. if isinstance(scale_v, list):
  173. if quant_axis == 0:
  174. for i, s in enumerate(scale_v):
  175. x_dequant[i] = x_dequant[i] * s / bnt
  176. quant_bias = x - x_dequant
  177. mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1)
  178. std_orig = x.reshape(x.shape[0], -1).std(-1)
  179. std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1)
  180. std_bias = std_orig / (std_quant + eps)
  181. else:
  182. for i, s in enumerate(scale_v):
  183. x_dequant[:, i] = x_quant[:, i] * s / bnt
  184. quant_bias = x - x_dequant
  185. mean_bias = np.array(
  186. [quant_bias[:, i].mean() for i in range(quant_bias.shape[1])]
  187. )
  188. std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
  189. std_quant = np.array(
  190. [x_dequant[:, i].std() for i in range(x_dequant.shape[1])]
  191. )
  192. std_bias = std_orig / (std_quant + eps)
  193. else:
  194. x_dequant = x_quant * scale_v / bnt
  195. mean_bias = (x - x_dequant).mean()
  196. std_bias = x.std() / (x_dequant.std() + eps)
  197. if mean_bias.ndim == 1:
  198. std_bias = np.resize(std_bias, x.shape)
  199. mean_bias = np.resize(mean_bias, x.shape)
  200. x_dequant = (mean_bias + x_dequant) * std_bias
  201. quantized_param_v = quant_tensor(
  202. x_dequant, scale_v, quant_axis, weight_bits
  203. )
  204. return quantized_param_v
  205. def stable_sigmoid(x):
  206. sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x)))
  207. return sig
  208. def calculate_quant_cos_error(orig_tensor, qdq_tensor):
  209. cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) / (
  210. np.linalg.norm(orig_tensor.flatten())
  211. * np.linalg.norm(qdq_tensor.flatten())
  212. )
  213. return cos_sim
  214. def move_persistable_var_to_global_block(program):
  215. # Move sub blocks persistable var to global block
  216. global_block = program.global_block()
  217. for _op in global_block.ops:
  218. if _op.type == "while":
  219. _block_id = _op.attr("sub_block").id
  220. _block = program.block(_block_id)
  221. persistables = []
  222. for _name, _var in _block.vars.items():
  223. if _var.persistable:
  224. global_block._clone_variable(_var)
  225. persistables.append(_name)
  226. for _name in persistables:
  227. _block._remove_var(_name)
  228. persistables.extend(_op.input('X'))
  229. _op.desc.set_input("X", persistables)
  230. def l2_loss(gt, pred):
  231. return ((gt - pred) ** 2).mean()
  232. class tqdm:
  233. def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
  234. self.total = total
  235. self.bar_format = bar_format
  236. self.ncols = ncols
  237. self.n = 0
  238. def update(self, n=1):
  239. self.n += n
  240. a = "=" * round((self.n / self.total) * self.ncols)
  241. b = " " * (self.ncols - len(a))
  242. prefix = self.bar_format.split('|')[0]
  243. sys.stderr.write(f"\r{prefix}|{a}=>{b}| {self.n}/{self.total}")
  244. sys.stderr.flush()
  245. def __enter__(self):
  246. return self
  247. def __exit__(self, exc_type, exc_val, exc_tb):
  248. sys.stderr.write('\n')