train.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.append(__dir__)
  21. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
  22. import yaml
  23. import paddle
  24. import paddle.distributed as dist
  25. from ppocr.data import build_dataloader, set_signal_handlers
  26. from ppocr.modeling.architectures import build_model
  27. from ppocr.losses import build_loss
  28. from ppocr.optimizer import build_optimizer
  29. from ppocr.postprocess import build_post_process
  30. from ppocr.metrics import build_metric
  31. from ppocr.utils.save_load import load_model
  32. from ppocr.utils.utility import set_seed
  33. from ppocr.modeling.architectures import apply_to_static
  34. import tools.program as program
  35. import tools.naive_sync_bn as naive_sync_bn
# NOTE(review): the return value is discarded, so this module-level call is
# kept only for whatever side effect paddle.distributed.get_world_size() has
# when the module is imported (if any) — confirm whether it is still needed.
dist.get_world_size()
  37. def main(config, device, logger, vdl_writer, seed):
  38. # init dist environment
  39. if config["Global"]["distributed"]:
  40. dist.init_parallel_env()
  41. global_config = config["Global"]
  42. # build dataloader
  43. set_signal_handlers()
  44. train_dataloader = build_dataloader(config, "Train", device, logger, seed)
  45. if len(train_dataloader) == 0:
  46. logger.error(
  47. "No Images in train dataset, please ensure\n"
  48. + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
  49. + "\t2. The annotation file and path in the configuration file are provided normally."
  50. )
  51. return
  52. if config["Eval"]:
  53. valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
  54. else:
  55. valid_dataloader = None
  56. step_pre_epoch = len(train_dataloader)
  57. # build post process
  58. post_process_class = build_post_process(config["PostProcess"], global_config)
  59. # build model
  60. # for rec algorithm
  61. if hasattr(post_process_class, "character"):
  62. char_num = len(getattr(post_process_class, "character"))
  63. if config["Architecture"]["algorithm"] in [
  64. "Distillation",
  65. ]: # distillation model
  66. for key in config["Architecture"]["Models"]:
  67. if (
  68. config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
  69. ): # for multi head
  70. if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
  71. char_num = char_num - 2
  72. if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
  73. char_num = char_num - 3
  74. out_channels_list = {}
  75. out_channels_list["CTCLabelDecode"] = char_num
  76. # update SARLoss params
  77. if (
  78. list(config["Loss"]["loss_config_list"][-1].keys())[0]
  79. == "DistillationSARLoss"
  80. ):
  81. config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
  82. "ignore_index"
  83. ] = (char_num + 1)
  84. out_channels_list["SARLabelDecode"] = char_num + 2
  85. elif any(
  86. "DistillationNRTRLoss" in d
  87. for d in config["Loss"]["loss_config_list"]
  88. ):
  89. out_channels_list["NRTRLabelDecode"] = char_num + 3
  90. config["Architecture"]["Models"][key]["Head"][
  91. "out_channels_list"
  92. ] = out_channels_list
  93. else:
  94. config["Architecture"]["Models"][key]["Head"][
  95. "out_channels"
  96. ] = char_num
  97. elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
  98. if config["PostProcess"]["name"] == "SARLabelDecode":
  99. char_num = char_num - 2
  100. if config["PostProcess"]["name"] == "NRTRLabelDecode":
  101. char_num = char_num - 3
  102. out_channels_list = {}
  103. out_channels_list["CTCLabelDecode"] = char_num
  104. # update SARLoss params
  105. if list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss":
  106. if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
  107. config["Loss"]["loss_config_list"][1]["SARLoss"] = {
  108. "ignore_index": char_num + 1
  109. }
  110. else:
  111. config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
  112. char_num + 1
  113. )
  114. out_channels_list["SARLabelDecode"] = char_num + 2
  115. elif list(config["Loss"]["loss_config_list"][1].keys())[0] == "NRTRLoss":
  116. out_channels_list["NRTRLabelDecode"] = char_num + 3
  117. config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
  118. else: # base rec model
  119. config["Architecture"]["Head"]["out_channels"] = char_num
  120. if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
  121. config["Loss"]["ignore_index"] = char_num - 1
  122. model = build_model(config["Architecture"])
  123. use_sync_bn = config["Global"].get("use_sync_bn", False)
  124. if use_sync_bn:
  125. if config["Global"].get("use_npu", False) or config["Global"].get(
  126. "use_xpu", False
  127. ):
  128. naive_sync_bn.convert_syncbn(model)
  129. else:
  130. model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
  131. logger.info("convert_sync_batchnorm")
  132. model = apply_to_static(model, config, logger)
  133. # build loss
  134. loss_class = build_loss(config["Loss"])
  135. # build optim
  136. optimizer, lr_scheduler = build_optimizer(
  137. config["Optimizer"],
  138. epochs=config["Global"]["epoch_num"],
  139. step_each_epoch=len(train_dataloader),
  140. model=model,
  141. )
  142. # build metric
  143. eval_class = build_metric(config["Metric"])
  144. logger.info("train dataloader has {} iters".format(len(train_dataloader)))
  145. if valid_dataloader is not None:
  146. logger.info("valid dataloader has {} iters".format(len(valid_dataloader)))
  147. use_amp = config["Global"].get("use_amp", False)
  148. amp_level = config["Global"].get("amp_level", "O2")
  149. amp_dtype = config["Global"].get("amp_dtype", "float16")
  150. amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
  151. amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
  152. if os.path.exists(
  153. os.path.join(config["Global"]["save_model_dir"], "train_result.json")
  154. ):
  155. try:
  156. os.remove(
  157. os.path.join(config["Global"]["save_model_dir"], "train_result.json")
  158. )
  159. except:
  160. pass
  161. if use_amp:
  162. AMP_RELATED_FLAGS_SETTING = {}
  163. if paddle.is_compiled_with_cuda():
  164. AMP_RELATED_FLAGS_SETTING.update(
  165. {
  166. "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
  167. "FLAGS_gemm_use_half_precision_compute_type": 0,
  168. }
  169. )
  170. paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
  171. scale_loss = config["Global"].get("scale_loss", 1.0)
  172. use_dynamic_loss_scaling = config["Global"].get(
  173. "use_dynamic_loss_scaling", False
  174. )
  175. scaler = paddle.amp.GradScaler(
  176. init_loss_scaling=scale_loss,
  177. use_dynamic_loss_scaling=use_dynamic_loss_scaling,
  178. )
  179. if amp_level == "O2":
  180. model, optimizer = paddle.amp.decorate(
  181. models=model,
  182. optimizers=optimizer,
  183. level=amp_level,
  184. master_weight=True,
  185. dtype=amp_dtype,
  186. )
  187. else:
  188. scaler = None
  189. # load pretrain model
  190. pre_best_model_dict = load_model(
  191. config, model, optimizer, config["Architecture"]["model_type"]
  192. )
  193. if config["Global"]["distributed"]:
  194. find_unused_parameters = config["Global"].get("find_unused_parameters", False)
  195. model = paddle.DataParallel(
  196. model, find_unused_parameters=find_unused_parameters
  197. )
  198. # start train
  199. program.train(
  200. config,
  201. train_dataloader,
  202. valid_dataloader,
  203. device,
  204. model,
  205. loss_class,
  206. optimizer,
  207. lr_scheduler,
  208. post_process_class,
  209. eval_class,
  210. pre_best_model_dict,
  211. logger,
  212. step_pre_epoch,
  213. vdl_writer,
  214. scaler,
  215. amp_level,
  216. amp_custom_black_list,
  217. amp_custom_white_list,
  218. amp_dtype,
  219. )
  220. def test_reader(config, device, logger):
  221. loader = build_dataloader(config, "Train", device, logger)
  222. import time
  223. starttime = time.time()
  224. count = 0
  225. try:
  226. for data in loader():
  227. count += 1
  228. if count % 1 == 0:
  229. batch_time = time.time() - starttime
  230. starttime = time.time()
  231. logger.info(
  232. "reader: {}, {}, {}".format(count, len(data[0]), batch_time)
  233. )
  234. except Exception as e:
  235. logger.info(e)
  236. logger.info("finish reader: {}, Success!".format(count))
  237. if __name__ == "__main__":
  238. config, device, logger, vdl_writer = program.preprocess(is_train=True)
  239. seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
  240. set_seed(seed)
  241. main(config, device, logger, vdl_writer, seed)
  242. # test_reader(config, device, logger)