quant_kl.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..")))
  22. sys.path.append(os.path.abspath(os.path.join(__dir__, "..", "..", "..", "tools")))
  23. import yaml
  24. import paddle
  25. import paddle.distributed as dist
  26. paddle.seed(2)
  27. from ppocr.data import build_dataloader, set_signal_handlers
  28. from ppocr.modeling.architectures import build_model
  29. from ppocr.losses import build_loss
  30. from ppocr.optimizer import build_optimizer
  31. from ppocr.postprocess import build_post_process
  32. from ppocr.metrics import build_metric
  33. from ppocr.utils.save_load import load_model
  34. import tools.program as program
  35. import paddleslim
  36. from paddleslim.dygraph.quant import QAT
  37. import numpy as np
  38. dist.get_world_size()
  39. class PACT(paddle.nn.Layer):
  40. def __init__(self):
  41. super(PACT, self).__init__()
  42. alpha_attr = paddle.ParamAttr(
  43. name=self.full_name() + ".pact",
  44. initializer=paddle.nn.initializer.Constant(value=20),
  45. learning_rate=1.0,
  46. regularizer=paddle.regularizer.L2Decay(2e-5),
  47. )
  48. self.alpha = self.create_parameter(shape=[1], attr=alpha_attr, dtype="float32")
  49. def forward(self, x):
  50. out_left = paddle.nn.functional.relu(x - self.alpha)
  51. out_right = paddle.nn.functional.relu(-self.alpha - x)
  52. x = x - out_left + out_right
  53. return x
  54. quant_config = {
  55. # weight preprocess type, default is None and no preprocessing is performed.
  56. "weight_preprocess_type": None,
  57. # activation preprocess type, default is None and no preprocessing is performed.
  58. "activation_preprocess_type": None,
  59. # weight quantize type, default is 'channel_wise_abs_max'
  60. "weight_quantize_type": "channel_wise_abs_max",
  61. # activation quantize type, default is 'moving_average_abs_max'
  62. "activation_quantize_type": "moving_average_abs_max",
  63. # weight quantize bit num, default is 8
  64. "weight_bits": 8,
  65. # activation quantize bit num, default is 8
  66. "activation_bits": 8,
  67. # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
  68. "dtype": "int8",
  69. # window size for 'range_abs_max' quantization. default is 10000
  70. "window_size": 10000,
  71. # The decay coefficient of moving average, default is 0.9
  72. "moving_rate": 0.9,
  73. # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
  74. "quantizable_layer_type": ["Conv2D", "Linear"],
  75. }
  76. def sample_generator(loader):
  77. def __reader__():
  78. for _, data in enumerate(loader):
  79. images = np.array(data[0])
  80. yield images
  81. return __reader__
  82. def sample_generator_layoutxlm_ser(loader):
  83. def __reader__():
  84. for _, data in enumerate(loader):
  85. input_ids = np.array(data[0])
  86. bbox = np.array(data[1])
  87. attention_mask = np.array(data[2])
  88. token_type_ids = np.array(data[3])
  89. images = np.array(data[4])
  90. yield [input_ids, bbox, attention_mask, token_type_ids, images]
  91. return __reader__
  92. def main(config, device, logger, vdl_writer):
  93. # init dist environment
  94. if config["Global"]["distributed"]:
  95. dist.init_parallel_env()
  96. global_config = config["Global"]
  97. # build dataloader
  98. set_signal_handlers()
  99. config["Train"]["loader"]["num_workers"] = 0
  100. is_layoutxlm_ser = (
  101. config["Architecture"]["model_type"] == "kie"
  102. and config["Architecture"]["Backbone"]["name"] == "LayoutXLMForSer"
  103. )
  104. train_dataloader = build_dataloader(config, "Train", device, logger)
  105. if config["Eval"]:
  106. config["Eval"]["loader"]["num_workers"] = 0
  107. valid_dataloader = build_dataloader(config, "Eval", device, logger)
  108. if is_layoutxlm_ser:
  109. train_dataloader = valid_dataloader
  110. else:
  111. valid_dataloader = None
  112. paddle.enable_static()
  113. exe = paddle.static.Executor(device)
  114. if "inference_model" in global_config.keys(): # , 'inference_model'):
  115. inference_model_dir = global_config["inference_model"]
  116. else:
  117. inference_model_dir = os.path.dirname(global_config["pretrained_model"])
  118. if not (
  119. os.path.exists(os.path.join(inference_model_dir, "inference.pdmodel"))
  120. and os.path.exists(os.path.join(inference_model_dir, "inference.pdiparams"))
  121. ):
  122. raise ValueError(
  123. "Please set inference model dir in Global.inference_model or Global.pretrained_model for post-quantization"
  124. )
  125. if is_layoutxlm_ser:
  126. generator = sample_generator_layoutxlm_ser(train_dataloader)
  127. else:
  128. generator = sample_generator(train_dataloader)
  129. paddleslim.quant.quant_post_static(
  130. executor=exe,
  131. model_dir=inference_model_dir,
  132. model_filename="inference.pdmodel",
  133. params_filename="inference.pdiparams",
  134. quantize_model_path=global_config["save_inference_dir"],
  135. sample_generator=generator,
  136. save_model_filename="inference.pdmodel",
  137. save_params_filename="inference.pdiparams",
  138. batch_size=1,
  139. batch_nums=None,
  140. )
  141. if __name__ == "__main__":
  142. config, device, logger, vdl_writer = program.preprocess(is_train=True)
  143. main(config, device, logger, vdl_writer)