export_model.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. import json
  17. import copy
  18. import shutil
  19. import paddle
  20. import paddle.nn as nn
  21. from paddle.jit import to_static
  22. from collections import OrderedDict
  23. from packaging import version
  24. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  25. from ppocr.modeling.architectures import build_model
  26. from ppocr.postprocess import build_post_process
  27. from ppocr.utils.save_load import load_model
  28. from ppocr.utils.logging import get_logger
  29. def represent_dictionary_order(self, dict_data):
  30. return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())
  31. def setup_orderdict():
  32. yaml.add_representer(OrderedDict, represent_dictionary_order)
  33. def dump_infer_config(config, path, logger):
  34. setup_orderdict()
  35. infer_cfg = OrderedDict()
  36. if not os.path.exists(os.path.dirname(path)):
  37. os.makedirs(os.path.dirname(path))
  38. model_name = None
  39. if config["Global"].get("model_name", None):
  40. model_name = config["Global"]["model_name"]
  41. infer_cfg["Global"] = {"model_name": model_name}
  42. if config["Global"].get("uniform_output_enabled", True):
  43. arch_config = config["Architecture"]
  44. if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
  45. common_dynamic_shapes = {
  46. "x": [[1, 3, 48, 160], [1, 3, 48, 320], [8, 3, 48, 3200]]
  47. }
  48. elif arch_config["model_type"] == "det":
  49. common_dynamic_shapes = {
  50. "x": [[1, 3, 32, 32], [1, 3, 736, 736], [1, 3, 4000, 4000]]
  51. }
  52. elif arch_config["algorithm"] == "SLANet":
  53. if model_name == "SLANet_plus":
  54. common_dynamic_shapes = {
  55. "x": [[1, 3, 32, 32], [1, 3, 64, 448], [1, 3, 488, 488]]
  56. }
  57. else:
  58. common_dynamic_shapes = {
  59. "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
  60. }
  61. elif arch_config["algorithm"] == "SLANeXt":
  62. common_dynamic_shapes = {
  63. "x": [[1, 3, 512, 512], [1, 3, 512, 512], [1, 3, 512, 512]]
  64. }
  65. elif arch_config["algorithm"] == "LaTeXOCR":
  66. common_dynamic_shapes = {
  67. "x": [[1, 1, 32, 32], [1, 1, 64, 448], [1, 1, 192, 672]]
  68. }
  69. elif arch_config["algorithm"] == "UniMERNet":
  70. common_dynamic_shapes = {
  71. "x": [[1, 1, 192, 672], [1, 1, 192, 672], [8, 1, 192, 672]]
  72. }
  73. elif arch_config["algorithm"] in ["PP-FormulaNet-L", "PP-FormulaNet_plus-L"]:
  74. common_dynamic_shapes = {
  75. "x": [[1, 1, 768, 768], [1, 1, 768, 768], [8, 1, 768, 768]]
  76. }
  77. elif arch_config["algorithm"] in [
  78. "PP-FormulaNet-S",
  79. "PP-FormulaNet_plus-S",
  80. "PP-FormulaNet_plus-M",
  81. ]:
  82. common_dynamic_shapes = {
  83. "x": [[1, 1, 384, 384], [1, 1, 384, 384], [8, 1, 384, 384]]
  84. }
  85. else:
  86. common_dynamic_shapes = None
  87. backend_keys = ["paddle_infer", "tensorrt"]
  88. hpi_config = {
  89. "backend_configs": {
  90. key: {
  91. (
  92. "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
  93. ): common_dynamic_shapes
  94. }
  95. for key in backend_keys
  96. }
  97. }
  98. if common_dynamic_shapes:
  99. infer_cfg["Hpi"] = hpi_config
  100. infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
  101. postprocess = OrderedDict()
  102. for k, v in config["PostProcess"].items():
  103. if config["Architecture"].get("algorithm") in [
  104. "LaTeXOCR",
  105. "UniMERNet",
  106. "PP-FormulaNet-L",
  107. "PP-FormulaNet-S",
  108. "PP-FormulaNet_plus-L",
  109. "PP-FormulaNet_plus-M",
  110. "PP-FormulaNet_plus-S",
  111. ]:
  112. if k != "rec_char_dict_path":
  113. postprocess[k] = v
  114. else:
  115. postprocess[k] = v
  116. if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
  117. tokenizer_file = config["Global"].get("rec_char_dict_path")
  118. if tokenizer_file is not None:
  119. with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
  120. character_dict = json.load(tokenizer_config_handle)
  121. postprocess["character_dict"] = character_dict
  122. elif config["Architecture"].get("algorithm") in [
  123. "UniMERNet",
  124. "PP-FormulaNet-L",
  125. "PP-FormulaNet-S",
  126. "PP-FormulaNet_plus-L",
  127. "PP-FormulaNet_plus-M",
  128. "PP-FormulaNet_plus-S",
  129. ]:
  130. tokenizer_file = config["Global"].get("rec_char_dict_path")
  131. fast_tokenizer_file = os.path.join(tokenizer_file, "tokenizer.json")
  132. tokenizer_config_file = os.path.join(tokenizer_file, "tokenizer_config.json")
  133. postprocess["character_dict"] = {}
  134. if fast_tokenizer_file is not None:
  135. with open(fast_tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
  136. character_dict = json.load(tokenizer_config_handle)
  137. postprocess["character_dict"]["fast_tokenizer_file"] = character_dict
  138. if tokenizer_config_file is not None:
  139. with open(
  140. tokenizer_config_file, encoding="utf-8"
  141. ) as tokenizer_config_handle:
  142. character_dict = json.load(tokenizer_config_handle)
  143. postprocess["character_dict"]["tokenizer_config_file"] = character_dict
  144. else:
  145. if config["Global"].get("character_dict_path") is not None:
  146. with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
  147. lines = f.readlines()
  148. character_dict = [line.strip("\n") for line in lines]
  149. postprocess["character_dict"] = character_dict
  150. infer_cfg["PostProcess"] = postprocess
  151. with open(path, "w", encoding="utf-8") as f:
  152. yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
  153. logger.info("Export inference config file to {}".format(os.path.join(path)))
  154. def dynamic_to_static(model, arch_config, logger, input_shape=None):
  155. if arch_config["algorithm"] == "SRN":
  156. max_text_length = arch_config["Head"]["max_text_length"]
  157. other_shape = [
  158. paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
  159. [
  160. paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
  161. paddle.static.InputSpec(
  162. shape=[None, max_text_length, 1], dtype="int64"
  163. ),
  164. paddle.static.InputSpec(
  165. shape=[None, 8, max_text_length, max_text_length], dtype="int64"
  166. ),
  167. paddle.static.InputSpec(
  168. shape=[None, 8, max_text_length, max_text_length], dtype="int64"
  169. ),
  170. ],
  171. ]
  172. model = to_static(model, input_spec=other_shape)
  173. elif arch_config["algorithm"] == "SAR":
  174. other_shape = [
  175. paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
  176. [paddle.static.InputSpec(shape=[None], dtype="float32")],
  177. ]
  178. model = to_static(model, input_spec=other_shape)
  179. elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
  180. other_shape = [
  181. paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
  182. ]
  183. model = to_static(model, input_spec=other_shape)
  184. elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
  185. other_shape = [
  186. paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
  187. ]
  188. model = to_static(model, input_spec=other_shape)
  189. elif arch_config["algorithm"] == "PREN":
  190. other_shape = [
  191. paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
  192. ]
  193. model = to_static(model, input_spec=other_shape)
  194. elif arch_config["model_type"] == "sr":
  195. other_shape = [
  196. paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
  197. ]
  198. model = to_static(model, input_spec=other_shape)
  199. elif arch_config["algorithm"] == "ViTSTR":
  200. other_shape = [
  201. paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
  202. ]
  203. model = to_static(model, input_spec=other_shape)
  204. elif arch_config["algorithm"] == "ABINet":
  205. if not input_shape:
  206. input_shape = [3, 32, 128]
  207. other_shape = [
  208. paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
  209. ]
  210. model = to_static(model, input_spec=other_shape)
  211. elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
  212. other_shape = [
  213. paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
  214. ]
  215. model = to_static(model, input_spec=other_shape)
  216. elif arch_config["algorithm"] in ["SATRN"]:
  217. other_shape = [
  218. paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
  219. ]
  220. model = to_static(model, input_spec=other_shape)
  221. elif arch_config["algorithm"] == "VisionLAN":
  222. other_shape = [
  223. paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
  224. ]
  225. model = to_static(model, input_spec=other_shape)
  226. elif arch_config["algorithm"] == "RobustScanner":
  227. max_text_length = arch_config["Head"]["max_text_length"]
  228. other_shape = [
  229. paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
  230. [
  231. paddle.static.InputSpec(
  232. shape=[
  233. None,
  234. ],
  235. dtype="float32",
  236. ),
  237. paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
  238. ],
  239. ]
  240. model = to_static(model, input_spec=other_shape)
  241. elif arch_config["algorithm"] == "CAN":
  242. other_shape = [
  243. [
  244. paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
  245. paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
  246. paddle.static.InputSpec(
  247. shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
  248. ),
  249. ]
  250. ]
  251. model = to_static(model, input_spec=other_shape)
  252. elif arch_config["algorithm"] == "LaTeXOCR":
  253. other_shape = [
  254. paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
  255. ]
  256. model = to_static(model, input_spec=other_shape)
  257. elif arch_config["algorithm"] == "UniMERNet":
  258. model = paddle.jit.to_static(
  259. model,
  260. input_spec=[
  261. paddle.static.InputSpec(shape=[-1, 1, 192, 672], dtype="float32")
  262. ],
  263. full_graph=True,
  264. )
  265. elif arch_config["algorithm"] == "SLANeXt":
  266. model = paddle.jit.to_static(
  267. model,
  268. input_spec=[
  269. paddle.static.InputSpec(shape=[-1, 3, 512, 512], dtype="float32")
  270. ],
  271. full_graph=True,
  272. )
  273. elif arch_config["algorithm"] in ["PP-FormulaNet-L", "PP-FormulaNet_plus-L"]:
  274. model = paddle.jit.to_static(
  275. model,
  276. input_spec=[
  277. paddle.static.InputSpec(shape=[-1, 1, 768, 768], dtype="float32")
  278. ],
  279. full_graph=True,
  280. )
  281. elif arch_config["algorithm"] in [
  282. "PP-FormulaNet-S",
  283. "PP-FormulaNet_plus-S",
  284. "PP-FormulaNet_plus-M",
  285. ]:
  286. model = paddle.jit.to_static(
  287. model,
  288. input_spec=[
  289. paddle.static.InputSpec(shape=[-1, 1, 384, 384], dtype="float32")
  290. ],
  291. full_graph=True,
  292. )
  293. elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
  294. input_spec = [
  295. paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
  296. paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
  297. paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
  298. paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
  299. paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"), # image
  300. ]
  301. if "Re" in arch_config["Backbone"]["name"]:
  302. input_spec.extend(
  303. [
  304. paddle.static.InputSpec(
  305. shape=[None, 512, 3], dtype="int64"
  306. ), # entities
  307. paddle.static.InputSpec(
  308. shape=[None, None, 2], dtype="int64"
  309. ), # relations
  310. ]
  311. )
  312. if model.backbone.use_visual_backbone is False:
  313. input_spec.pop(4)
  314. model = to_static(model, input_spec=[input_spec])
  315. else:
  316. infer_shape = [3, -1, -1]
  317. if arch_config["model_type"] == "rec":
  318. infer_shape = [3, 32, -1] # for rec model, H must be 32
  319. if (
  320. "Transform" in arch_config
  321. and arch_config["Transform"] is not None
  322. and arch_config["Transform"]["name"] == "TPS"
  323. ):
  324. logger.info(
  325. "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
  326. )
  327. infer_shape[-1] = 100
  328. elif arch_config["model_type"] == "table":
  329. infer_shape = [3, 488, 488]
  330. if arch_config["algorithm"] == "TableMaster":
  331. infer_shape = [3, 480, 480]
  332. if arch_config["algorithm"] == "SLANet":
  333. infer_shape = [3, -1, -1]
  334. model = to_static(
  335. model,
  336. input_spec=[
  337. paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
  338. ],
  339. )
  340. if (
  341. arch_config["model_type"] != "sr"
  342. and arch_config["Backbone"]["name"] == "PPLCNetV3"
  343. ):
  344. # for rep lcnetv3
  345. for layer in model.sublayers():
  346. if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
  347. layer.rep()
  348. return model
  349. def export_single_model(
  350. model,
  351. arch_config,
  352. save_path,
  353. logger,
  354. yaml_path,
  355. config,
  356. input_shape=None,
  357. quanter=None,
  358. ):
  359. model = dynamic_to_static(model, arch_config, logger, input_shape)
  360. if quanter is None:
  361. try:
  362. import encryption # Attempt to import the encryption module for AIStudio's encryption model
  363. except (
  364. ModuleNotFoundError
  365. ): # Encryption is not needed if the module cannot be imported
  366. print("Skipping import of the encryption module")
  367. paddle_version = version.parse(paddle.__version__)
  368. if config["Global"].get("export_with_pir", True):
  369. assert (
  370. paddle_version >= version.parse("3.0.0b2")
  371. or paddle_version == version.parse("0.0.0")
  372. ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
  373. paddle.jit.save(model, save_path)
  374. else:
  375. if paddle_version >= version.parse(
  376. "3.0.0b2"
  377. ) or paddle_version == version.parse("0.0.0"):
  378. model.forward.rollback()
  379. with paddle.pir_utils.OldIrGuard():
  380. model = dynamic_to_static(model, arch_config, logger, input_shape)
  381. paddle.jit.save(model, save_path)
  382. else:
  383. paddle.jit.save(model, save_path)
  384. else:
  385. quanter.save_quantized_model(model, save_path)
  386. logger.info("inference model is saved to {}".format(save_path))
  387. return
  388. def convert_bn(model):
  389. for n, m in model.named_children():
  390. if isinstance(m, nn.SyncBatchNorm):
  391. bn = nn.BatchNorm2D(
  392. m._num_features, m._momentum, m._epsilon, m._weight_attr, m._bias_attr
  393. )
  394. bn.set_dict(m.state_dict())
  395. setattr(model, n, bn)
  396. else:
  397. convert_bn(m)
  398. def export(config, base_model=None, save_path=None):
  399. if paddle.distributed.get_rank() != 0:
  400. return
  401. logger = get_logger()
  402. # build post process
  403. post_process_class = build_post_process(config["PostProcess"], config["Global"])
  404. # build model
  405. # for rec algorithm
  406. if hasattr(post_process_class, "character"):
  407. char_num = len(getattr(post_process_class, "character"))
  408. if config["Architecture"]["algorithm"] in [
  409. "Distillation",
  410. ]: # distillation model
  411. for key in config["Architecture"]["Models"]:
  412. if (
  413. config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
  414. ): # multi head
  415. out_channels_list = {}
  416. if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
  417. char_num = char_num - 2
  418. if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
  419. char_num = char_num - 3
  420. out_channels_list["CTCLabelDecode"] = char_num
  421. out_channels_list["SARLabelDecode"] = char_num + 2
  422. out_channels_list["NRTRLabelDecode"] = char_num + 3
  423. config["Architecture"]["Models"][key]["Head"][
  424. "out_channels_list"
  425. ] = out_channels_list
  426. else:
  427. config["Architecture"]["Models"][key]["Head"][
  428. "out_channels"
  429. ] = char_num
  430. # just one final tensor needs to exported for inference
  431. config["Architecture"]["Models"][key]["return_all_feats"] = False
  432. elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
  433. out_channels_list = {}
  434. char_num = len(getattr(post_process_class, "character"))
  435. if config["PostProcess"]["name"] == "SARLabelDecode":
  436. char_num = char_num - 2
  437. if config["PostProcess"]["name"] == "NRTRLabelDecode":
  438. char_num = char_num - 3
  439. out_channels_list["CTCLabelDecode"] = char_num
  440. out_channels_list["SARLabelDecode"] = char_num + 2
  441. out_channels_list["NRTRLabelDecode"] = char_num + 3
  442. config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
  443. else: # base rec model
  444. config["Architecture"]["Head"]["out_channels"] = char_num
  445. # for sr algorithm
  446. if config["Architecture"]["model_type"] == "sr":
  447. config["Architecture"]["Transform"]["infer_mode"] = True
  448. # for latexocr algorithm
  449. if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
  450. config["Architecture"]["Backbone"]["is_predict"] = True
  451. config["Architecture"]["Backbone"]["is_export"] = True
  452. config["Architecture"]["Head"]["is_export"] = True
  453. if config["Architecture"].get("algorithm") in ["UniMERNet"]:
  454. config["Architecture"]["Backbone"]["is_export"] = True
  455. config["Architecture"]["Head"]["is_export"] = True
  456. if config["Architecture"].get("algorithm") in [
  457. "PP-FormulaNet-S",
  458. "PP-FormulaNet-L",
  459. "PP-FormulaNet_plus-S",
  460. "PP-FormulaNet_plus-M",
  461. "PP-FormulaNet_plus-L",
  462. ]:
  463. config["Architecture"]["Head"]["is_export"] = True
  464. if base_model is not None:
  465. model = base_model
  466. if isinstance(model, paddle.DataParallel):
  467. model = copy.deepcopy(model._layers)
  468. else:
  469. model = copy.deepcopy(model)
  470. else:
  471. model = build_model(config["Architecture"])
  472. load_model(config, model, model_type=config["Architecture"]["model_type"])
  473. convert_bn(model)
  474. model.eval()
  475. if not save_path:
  476. save_path = config["Global"]["save_inference_dir"]
  477. yaml_path = os.path.join(save_path, "inference.yml")
  478. arch_config = config["Architecture"]
  479. if (
  480. arch_config["algorithm"] in ["SVTR", "CPPD"]
  481. and arch_config["Head"]["name"] != "MultiHead"
  482. ):
  483. input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
  484. "image_shape"
  485. ]
  486. elif arch_config["algorithm"].lower() == "ABINet".lower():
  487. rec_rs = [
  488. c
  489. for c in config["Eval"]["dataset"]["transforms"]
  490. if "ABINetRecResizeImg" in c
  491. ]
  492. input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
  493. else:
  494. input_shape = None
  495. dump_infer_config(config, yaml_path, logger)
  496. if arch_config["algorithm"] in [
  497. "Distillation",
  498. ]: # distillation model
  499. archs = list(arch_config["Models"].values())
  500. for idx, name in enumerate(model.model_name_list):
  501. sub_model_save_path = os.path.join(save_path, name, "inference")
  502. export_single_model(
  503. model.model_list[idx],
  504. archs[idx],
  505. sub_model_save_path,
  506. logger,
  507. yaml_path,
  508. config,
  509. )
  510. else:
  511. save_path = os.path.join(save_path, "inference")
  512. export_single_model(
  513. model,
  514. arch_config,
  515. save_path,
  516. logger,
  517. yaml_path,
  518. config,
  519. input_shape=input_shape,
  520. )