# sensitivity_anal.py (6.2 KB)
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(__file__)
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.join(__dir__, "..", "..", ".."))
  22. sys.path.append(os.path.join(__dir__, "..", "..", "..", "tools"))
  23. import paddle
  24. import paddle.distributed as dist
  25. from ppocr.data import build_dataloader, set_signal_handlers
  26. from ppocr.modeling.architectures import build_model
  27. from ppocr.losses import build_loss
  28. from ppocr.optimizer import build_optimizer
  29. from ppocr.postprocess import build_post_process
  30. from ppocr.metrics import build_metric
  31. from ppocr.utils.save_load import load_model
  32. import tools.program as program
# NOTE(review): the return value is discarded. Presumably this call is kept for
# an initialization side effect inside paddle.distributed — confirm, or remove
# it if it is a leftover.
dist.get_world_size()
  34. def get_pruned_params(parameters):
  35. params = []
  36. for param in parameters:
  37. if (
  38. len(param.shape) == 4
  39. and "depthwise" not in param.name
  40. and "transpose" not in param.name
  41. and "conv2d_57" not in param.name
  42. and "conv2d_56" not in param.name
  43. ):
  44. params.append(param.name)
  45. return params
  46. def main(config, device, logger, vdl_writer):
  47. # init dist environment
  48. if config["Global"]["distributed"]:
  49. dist.init_parallel_env()
  50. global_config = config["Global"]
  51. # build dataloader
  52. set_signal_handlers()
  53. train_dataloader = build_dataloader(config, "Train", device, logger)
  54. if config["Eval"]:
  55. valid_dataloader = build_dataloader(config, "Eval", device, logger)
  56. else:
  57. valid_dataloader = None
  58. # build post process
  59. post_process_class = build_post_process(config["PostProcess"], global_config)
  60. # build model
  61. # for rec algorithm
  62. if hasattr(post_process_class, "character"):
  63. char_num = len(getattr(post_process_class, "character"))
  64. config["Architecture"]["Head"]["out_channels"] = char_num
  65. model = build_model(config["Architecture"])
  66. if config["Architecture"]["model_type"] == "det":
  67. input_shape = [1, 3, 640, 640]
  68. elif config["Architecture"]["model_type"] == "rec":
  69. input_shape = [1, 3, 32, 320]
  70. flops = paddle.flops(model, input_shape)
  71. logger.info("FLOPs before pruning: {}".format(flops))
  72. from paddleslim.dygraph import FPGMFilterPruner
  73. model.train()
  74. pruner = FPGMFilterPruner(model, input_shape)
  75. # build loss
  76. loss_class = build_loss(config["Loss"])
  77. # build optim
  78. optimizer, lr_scheduler = build_optimizer(
  79. config["Optimizer"],
  80. epochs=config["Global"]["epoch_num"],
  81. step_each_epoch=len(train_dataloader),
  82. model=model,
  83. )
  84. # build metric
  85. eval_class = build_metric(config["Metric"])
  86. # load pretrain model
  87. pre_best_model_dict = load_model(config, model, optimizer)
  88. logger.info(
  89. "train dataloader has {} iters, valid dataloader has {} iters".format(
  90. len(train_dataloader), len(valid_dataloader)
  91. )
  92. )
  93. # build metric
  94. eval_class = build_metric(config["Metric"])
  95. logger.info(
  96. "train dataloader has {} iters, valid dataloader has {} iters".format(
  97. len(train_dataloader), len(valid_dataloader)
  98. )
  99. )
  100. def eval_fn():
  101. metric = program.eval(
  102. model, valid_dataloader, post_process_class, eval_class, False
  103. )
  104. if config["Architecture"]["model_type"] == "det":
  105. main_indicator = "hmean"
  106. else:
  107. main_indicator = "acc"
  108. logger.info("metric[{}]: {}".format(main_indicator, metric[main_indicator]))
  109. return metric[main_indicator]
  110. run_sensitive_analysis = False
  111. """
  112. run_sensitive_analysis=True:
  113. Automatically compute the sensitivities of convolutions in a model.
  114. The sensitivity of a convolution is the losses of accuracy on test dataset in
  115. different pruned ratios. The sensitivities can be used to get a group of best
  116. ratios with some condition.
  117. run_sensitive_analysis=False:
  118. Set prune trim ratio to a fixed value, such as 10%. The larger the value,
  119. the more convolution weights will be cropped.
  120. """
  121. if run_sensitive_analysis:
  122. params_sensitive = pruner.sensitive(
  123. eval_func=eval_fn,
  124. sen_file="./deploy/slim/prune/sen.pickle",
  125. skip_vars=[
  126. "conv2d_57.w_0",
  127. "conv2d_transpose_2.w_0",
  128. "conv2d_transpose_3.w_0",
  129. ],
  130. )
  131. logger.info(
  132. "The sensitivity analysis results of model parameters saved in sen.pickle"
  133. )
  134. # calculate pruned params's ratio
  135. params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
  136. for key in params_sensitive.keys():
  137. logger.info("{}, {}".format(key, params_sensitive[key]))
  138. else:
  139. params_sensitive = {}
  140. for param in model.parameters():
  141. if "transpose" not in param.name and "linear" not in param.name:
  142. # set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
  143. params_sensitive[param.name] = 0.1
  144. plan = pruner.prune_vars(params_sensitive, [0])
  145. flops = paddle.flops(model, input_shape)
  146. logger.info("FLOPs after pruning: {}".format(flops))
  147. # start train
  148. program.train(
  149. config,
  150. train_dataloader,
  151. valid_dataloader,
  152. device,
  153. model,
  154. loss_class,
  155. optimizer,
  156. lr_scheduler,
  157. post_process_class,
  158. eval_class,
  159. pre_best_model_dict,
  160. logger,
  161. vdl_writer,
  162. )
if __name__ == "__main__":
    # Script entry point: program.preprocess(is_train=True) yields the four
    # values (config, device, logger, vdl_writer) that main() takes.
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    main(config, device, logger, vdl_writer)