ptq.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import logging
  16. import os
  17. import numpy as np
  18. import paddle
  19. from paddle.nn.quant import quant_layers
  20. from ...static.log_helper import get_logger
  21. from ...static.quantization.utils import (
  22. _get_input_name_index,
  23. _get_op_input_var_names,
  24. _get_op_output_var_names,
  25. _get_output_name_index,
  26. )
  27. from . import fuse_utils, ptq_config, ptq_hooks, ptq_quantizer, utils
  28. from .ptq_registry import PTQRegistry
  29. INFER_MODEL_SUFFIX = ".pdmodel"
  30. INFER_PARAMS_SUFFIX = ".pdiparams"
  31. _logger = get_logger(
  32. __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
  33. )
  34. class ImperativePTQ:
  35. """
  36. Static post training quantization.
  37. """
  38. def __init__(self, quant_config=ptq_config.default_ptq_config):
  39. """
  40. Constructor.
  41. Args:
  42. quant_config(PTQConfig): the config of post training quantization.
  43. The config has weight_quantizer and activation_quantizer.
  44. In default, the weight_quantizer is PerChannelAbsmaxQuantizer
  45. and the activation_quantizer is KLQuantizer.
  46. """
  47. super().__init__()
  48. assert isinstance(quant_config, ptq_config.PTQConfig)
  49. self._quant_config = quant_config
  50. def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
  51. """
  52. Add quant config and hook to the target layer.
  53. Args:
  54. model(paddle.nn.Layer): The model to be quantized.
  55. inplace(bool): Whether apply quantization to the input model.
  56. Default: False.
  57. fuse(bool): Whether to fuse layers.
  58. Default: False.
  59. fuse_list(list): The layers' names to be fused. For example,
  60. "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
  61. A TypeError would be raised if "fuse" was set as
  62. True but "fuse_list" was None.
  63. Default: None.
  64. Return
  65. quantized_model(paddle.nn.Layer): The quantized model.
  66. """
  67. assert isinstance(
  68. model, paddle.nn.Layer
  69. ), "The model must be the instance of paddle.nn.Layer."
  70. if not inplace:
  71. model = copy.deepcopy(model)
  72. if fuse:
  73. model.eval()
  74. model = fuse_utils.fuse_layers(model, fuse_list)
  75. for name, layer in model.named_sublayers():
  76. if (
  77. PTQRegistry.is_supported_layer(layer)
  78. and utils.is_leaf_layer(layer)
  79. and not self._is_skip_layer(layer)
  80. ):
  81. # Add quant config
  82. quant_config = copy.deepcopy(self._quant_config)
  83. if PTQRegistry.is_simulated_quant_layer(layer):
  84. quant_config.enable_in_act_quantizer = True
  85. layer._quant_config = quant_config
  86. # register hook
  87. hook = ptq_hooks.quant_forward_post_hook
  88. quant_hook_handle = layer.register_forward_post_hook(hook)
  89. quant_config.quant_hook_handle = quant_hook_handle
  90. layer._forward_post_hooks.move_to_end(
  91. quant_hook_handle._hook_id, last=False
  92. )
  93. return model
  94. def save_quantized_model(self, model, path, input_spec=None, **config):
  95. """
  96. 1. Convert the quantized model
  97. 2. Call jit.save to save the inference model
  98. 3. Post process the inference model.
  99. Args:
  100. model (Layer): The model to be saved.
  101. path (str): The path prefix to save model. The format is
  102. ``dirname/file_prefix`` or ``file_prefix``.
  103. input_spec (list[InputSpec|Tensor], optional): Describes the input
  104. of the saved model's forward method, which can be described by
  105. InputSpec or example Tensor. If None, all input variables of
  106. the original Layer's forward method would be the inputs of
  107. the saved model. Default None.
  108. **config (dict, optional): Other save configuration options for
  109. compatibility. We do not recommend using these configurations,
  110. they may be removed in the future. If not necessary, DO NOT use
  111. them. Default None.
  112. The following options are currently supported:
  113. (1) output_spec (list[Tensor]): Selects the output targets of
  114. the saved model. By default, all return variables of original
  115. Layer's forward method are kept as the output of the saved model.
  116. If the provided ``output_spec`` list is not all output variables,
  117. the saved model will be pruned according to the given
  118. ``output_spec`` list.
  119. Returns:
  120. None
  121. """
  122. assert isinstance(
  123. model, paddle.nn.Layer
  124. ), "The model must be the instance of paddle.nn.Layer."
  125. # Convert and save dygraph quantized model
  126. self._convert(model)
  127. paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config)
  128. # Load inference program
  129. is_dynamic_mode = False
  130. if paddle.in_dynamic_mode():
  131. is_dynamic_mode = True
  132. paddle.enable_static()
  133. place = paddle.CPUPlace()
  134. scope = paddle.static.global_scope()
  135. exe = paddle.static.Executor(place)
  136. dirname = os.path.dirname(path)
  137. basename = os.path.basename(path)
  138. model_filename = basename + INFER_MODEL_SUFFIX
  139. params_filename = basename + INFER_PARAMS_SUFFIX
  140. [
  141. infer_program,
  142. feed_target_names,
  143. fetch_targets,
  144. ] = paddle.static.load_inference_model(
  145. path_prefix=dirname,
  146. executor=exe,
  147. model_filename=model_filename,
  148. params_filename=params_filename,
  149. )
  150. # Process inference program
  151. self._clean_up(infer_program)
  152. self._gather_input_thresholds(infer_program, scope)
  153. self._remove_scale_op(infer_program)
  154. # Save final program
  155. model_name = None
  156. if model_filename is None:
  157. model_name = "model"
  158. elif model_filename.endswith(".pdmodel"):
  159. model_name = model_filename.rsplit(".", 1)[0]
  160. else:
  161. model_name = model_filename
  162. path_prefix = os.path.join(dirname, model_name)
  163. feed_vars = [
  164. infer_program.global_block().var(name) for name in feed_target_names
  165. ]
  166. paddle.static.save_inference_model(
  167. path_prefix,
  168. feed_vars,
  169. fetch_targets,
  170. executor=exe,
  171. program=infer_program.clone(),
  172. )
  173. if is_dynamic_mode:
  174. paddle.disable_static()
  175. def _convert(self, model):
  176. """
  177. Convert the quantized model.
  178. Args:
  179. model(paddle.nn.Layer): The quantized model.
  180. inplace(bool): Whether apply conversion to the input model.
  181. Default: False.
  182. Returns:
  183. None
  184. """
  185. for name, sub_layer in model.named_sublayers():
  186. if self._is_quant_layer(sub_layer):
  187. sub_layer._quant_config.quant_hook_handle.remove()
  188. self._cal_thresholds(model)
  189. for name, sub_layer in model.named_sublayers():
  190. if self._is_quant_layer(sub_layer):
  191. self._save_output_thresholds(sub_layer, sub_layer._quant_config)
  192. self._wrap_simulated_layers(model)
  193. def _cal_thresholds(self, model):
  194. """
  195. Calculate the thresholds of inputs and outputs.
  196. Args:
  197. model(paddle.nn.Layer): The quantized model.
  198. Returns:
  199. None
  200. """
  201. assert isinstance(
  202. model, paddle.nn.Layer
  203. ), "The input model must be the instance of paddle.nn.Layer."
  204. total_num = 0
  205. cur_num = 0
  206. for name, sub_layer in model.named_sublayers():
  207. if self._is_quant_layer(sub_layer):
  208. total_num += 1
  209. for name, sub_layer in model.named_sublayers():
  210. if self._is_quant_layer(sub_layer):
  211. cur_num += 1
  212. if cur_num % 5 == 0:
  213. _logger.info(f"Process the {cur_num} / {total_num} layer")
  214. quant_config = sub_layer._quant_config
  215. if quant_config.enable_in_act_quantizer:
  216. quant_config.in_act_quantizer.cal_thresholds()
  217. quant_config.out_act_quantizer.cal_thresholds()
  218. if PTQRegistry.is_simulated_quant_layer(sub_layer):
  219. weights = (sub_layer.weight,)
  220. quant_config.wt_quantizer.sample_data(sub_layer, weights)
  221. quant_config.wt_quantizer.cal_thresholds()
  222. def _save_output_thresholds(self, sub_layer, quant_config):
  223. """
  224. Save the output thresholds to the layer.
  225. Args:
  226. sub_layer(paddle.nn.Layer): The quantized layer.
  227. quant_config(PTQConfig): the quant config for the layer.
  228. Returns:
  229. None
  230. """
  231. assert isinstance(
  232. sub_layer, paddle.nn.Layer
  233. ), "The input model must be the instance of paddle.nn.Layer."
  234. layer_info = PTQRegistry.layer_info(sub_layer)
  235. output_names = layer_info.output_names
  236. output_thresholds = quant_config.out_act_quantizer.thresholds
  237. assert len(output_names) == 1
  238. if len(output_thresholds) == 1:
  239. save_name = output_names[0] + str(0) + "_threshold"
  240. sub_layer._set_op_attrs({save_name: output_thresholds[0]})
  241. sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
  242. else:
  243. _logger.warning(
  244. f"output_thresholds shape of {output_names[0]} need to be 1, but received {len(output_thresholds)}"
  245. )
  246. def _wrap_simulated_layers(self, model):
  247. """
  248. Replace conv2d and linear with the quantized layers, and save
  249. thresholds into the fake layers.
  250. Args:
  251. model(paddle.nn.Layer): The model to be quantized.
  252. Returns:
  253. None
  254. """
  255. assert isinstance(
  256. model, paddle.nn.Layer
  257. ), "The input model must be the instance of paddle.nn.Layer."
  258. for name, sub_layer in model.named_sublayers():
  259. if self._is_quant_layer(
  260. sub_layer
  261. ) and PTQRegistry.is_simulated_quant_layer(sub_layer):
  262. quant_config = sub_layer._quant_config
  263. assert quant_config.enable_in_act_quantizer is True
  264. wt_quantizer = quant_config.wt_quantizer
  265. in_act_quantizer = quant_config.in_act_quantizer
  266. # create layer
  267. quant_layer_name = None
  268. for key, value in utils.layer_name_map.items():
  269. if isinstance(sub_layer, value):
  270. quant_layer_name = 'Quantized' + key
  271. break
  272. assert quant_layer_name is not None
  273. if isinstance(wt_quantizer, ptq_quantizer.AbsmaxQuantizer):
  274. weight_quantize_type = "abs_max"
  275. else:
  276. weight_quantize_type = "channel_wise_abs_max"
  277. kwargs = {
  278. "weight_quantize_type": weight_quantize_type,
  279. "activation_quantize_type": "moving_average_abs_max",
  280. "weight_bits": wt_quantizer.quant_bits,
  281. "activation_bits": in_act_quantizer.quant_bits,
  282. }
  283. quant_layer = quant_layers.__dict__[quant_layer_name](
  284. sub_layer, **kwargs
  285. )
  286. # save the input thresholds
  287. assert hasattr(quant_layer, "_fake_quant_input")
  288. assert hasattr(quant_layer._fake_quant_input, "_scale")
  289. if len(in_act_quantizer.thresholds) == 1:
  290. input_threshold = np.array(
  291. [in_act_quantizer.thresholds[0]], dtype=np.float32
  292. )
  293. quant_layer._fake_quant_input._scale.set_value(
  294. input_threshold
  295. )
  296. assert hasattr(quant_layer, "_fake_quant_weight")
  297. assert hasattr(quant_layer._fake_quant_weight, "_scale")
  298. assert len(wt_quantizer.thresholds) == 1
  299. weight_threshold = wt_quantizer.thresholds[0]
  300. if isinstance(weight_threshold, list):
  301. weight_threshold = np.array(
  302. weight_threshold, dtype=np.float32
  303. )
  304. else:
  305. weight_threshold = np.array(
  306. [weight_threshold], dtype=np.float32
  307. )
  308. quant_layer._fake_quant_weight._scale.set_value(
  309. weight_threshold
  310. )
  311. # save the output thresholds
  312. self._save_output_thresholds(quant_layer, quant_config)
  313. # replace the layer
  314. parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
  315. model, name
  316. )
  317. setattr(parent_layer, sub_name, quant_layer)
  318. def _gather_input_thresholds(self, program, scope):
  319. """
  320. Get and save input thresholds from the front ops.
  321. Args:
  322. program(Program): the input infer program.
  323. scope(Scope): the corresponding scope for the program.
  324. Returns:
  325. None
  326. """
  327. for op in utils.program_all_ops(program):
  328. for in_var_name in _get_op_input_var_names(op):
  329. previous_op = utils.find_previous_op(op.block, in_var_name)
  330. if previous_op is None:
  331. continue
  332. if (
  333. "quantize_dequantize" in previous_op.type
  334. or previous_op.type == "moving_average_abs_max_scale"
  335. ):
  336. attr_name = previous_op.output('OutScale')[0]
  337. in_threshold = utils.load_variable_data(scope, attr_name)
  338. in_threshold = utils.fp_numpy_to_naive(in_threshold)
  339. argname, index = _get_input_name_index(op, in_var_name)
  340. op._set_attr(
  341. argname + str(index) + "_threshold", in_threshold
  342. )
  343. op._set_attr("with_quant_attr", True)
  344. else:
  345. for out_var_name in _get_op_output_var_names(previous_op):
  346. if out_var_name != in_var_name:
  347. continue
  348. argname, index = _get_output_name_index(
  349. previous_op, out_var_name
  350. )
  351. attr_name = argname + str(index) + "_threshold"
  352. if not previous_op.has_attr(attr_name):
  353. continue
  354. threshold = previous_op.attr(attr_name)
  355. argname, index = _get_input_name_index(op, in_var_name)
  356. attr_name = argname + str(index) + "_threshold"
  357. op._set_attr(attr_name, threshold)
  358. op._set_attr("with_quant_attr", True)
  359. def _clean_up(self, program):
  360. """
  361. Remove useless thresholds which are added in jit.save.
  362. Args:
  363. program(Program): the input infer program.
  364. Returns:
  365. None
  366. """
  367. def _helper(op, next_op, old_attr_name, new_attr_name):
  368. if (
  369. op.has_attr(old_attr_name)
  370. and next_op.has_attr(old_attr_name)
  371. and op.attr(old_attr_name) == next_op.attr(old_attr_name)
  372. ):
  373. threshold = op.attr(old_attr_name)
  374. op._remove_attr(old_attr_name)
  375. next_op._remove_attr(old_attr_name)
  376. next_op._set_attr(new_attr_name, threshold)
  377. next_op._set_attr("with_quant_attr", True)
  378. for op in utils.program_all_ops(program):
  379. if "quantize_dequantize" in op.type:
  380. # remove the thresholds in fake ops
  381. for attr_name in op.attr_names:
  382. if "_threshold" in attr_name:
  383. op._remove_attr(attr_name)
  384. elif op.type in ["conv2d", "matmul"]:
  385. # change the thresholds in conv2d/matmul + eleadd
  386. arg_name = "Output" if op.type == "conv2d" else "Out"
  387. out_var_name = op.output(arg_name)[0]
  388. next_ops = utils.find_next_ops(op.block, out_var_name)
  389. if len(next_ops) > 1 or next_ops[0].type != "elementwise_add":
  390. continue
  391. next_op = next_ops[0]
  392. argname, index = _get_output_name_index(op, out_var_name)
  393. old_attr_name = argname + str(index) + "_threshold"
  394. argname, index = _get_output_name_index(
  395. next_op, next_op.output("Out")[0]
  396. )
  397. new_attr_name = argname + str(index) + "_threshold"
  398. _helper(op, next_op, old_attr_name, new_attr_name)
  399. _helper(op, next_op, "out_threshold", "out_threshold")
  400. def _remove_scale_op(self, program):
  401. """
  402. Remove the moving_average_abs_max_scale op.
  403. """
  404. for op in utils.program_all_ops(program):
  405. if op.type == "moving_average_abs_max_scale":
  406. in_var_name = op.input("X")[0]
  407. out_var_name = op.output("Out")[0]
  408. next_ops = utils.find_next_ops(op.block, out_var_name)
  409. for next_op in next_ops:
  410. next_op._rename_input(out_var_name, in_var_name)
  411. @staticmethod
  412. def _is_skip_layer(layer):
  413. return hasattr(layer, "skip_quant") and layer.skip_quant is True
  414. @staticmethod
  415. def _is_quant_layer(layer):
  416. return hasattr(layer, "_quant_config")