__init__.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import importlib
  16. from paddle.jit import to_static
  17. from paddle.static import InputSpec
  18. from .base_model import BaseModel
  19. from .distillation_model import DistillationModel
  20. __all__ = ["build_model", "apply_to_static"]
  21. def build_model(config):
  22. config = copy.deepcopy(config)
  23. if not "name" in config:
  24. arch = BaseModel(config)
  25. else:
  26. name = config.pop("name")
  27. mod = importlib.import_module(__name__)
  28. arch = getattr(mod, name)(config)
  29. return arch
  30. def apply_to_static(model, config, logger):
  31. if config["Global"].get("to_static", False) is not True:
  32. return model
  33. assert (
  34. "d2s_train_image_shape" in config["Global"]
  35. ), "d2s_train_image_shape must be assigned for static training mode..."
  36. supported_list = [
  37. "DB",
  38. "SVTR_LCNet",
  39. "TableMaster",
  40. "LayoutXLM",
  41. "SLANet",
  42. "SVTR",
  43. "SVTR_HGNet",
  44. "LaTeXOCR",
  45. "UniMERNet",
  46. "PP-FormulaNet-S",
  47. "PP-FormulaNet-L",
  48. ]
  49. if config["Architecture"]["algorithm"] in ["Distillation"]:
  50. algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
  51. else:
  52. algo = config["Architecture"]["algorithm"]
  53. assert (
  54. algo in supported_list
  55. ), f"algorithms that supports static training must in in {supported_list} but got {algo}"
  56. specs = [
  57. InputSpec([None] + config["Global"]["d2s_train_image_shape"], dtype="float32")
  58. ]
  59. if algo == "SVTR_LCNet":
  60. specs.append(
  61. [
  62. InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
  63. InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
  64. InputSpec([None], dtype="int64"),
  65. InputSpec([None], dtype="float64"),
  66. ]
  67. )
  68. elif algo == "TableMaster":
  69. specs.append(
  70. [
  71. InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
  72. InputSpec(
  73. [None, config["Global"]["max_text_length"], 4], dtype="float32"
  74. ),
  75. InputSpec(
  76. [None, config["Global"]["max_text_length"], 1], dtype="float32"
  77. ),
  78. InputSpec([None, 6], dtype="float32"),
  79. ]
  80. )
  81. elif algo == "LayoutXLM":
  82. specs = [
  83. [
  84. InputSpec(shape=[None, 512], dtype="int64"), # input_ids
  85. InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
  86. InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
  87. InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
  88. InputSpec(shape=[None, 3, 224, 224], dtype="float32"), # image
  89. InputSpec(shape=[None, 512], dtype="int64"), # label
  90. ]
  91. ]
  92. elif algo == "SLANet":
  93. specs.append(
  94. [
  95. InputSpec(
  96. [None, config["Global"]["max_text_length"] + 2], dtype="int64"
  97. ),
  98. InputSpec(
  99. [None, config["Global"]["max_text_length"] + 2, 4], dtype="float32"
  100. ),
  101. InputSpec(
  102. [None, config["Global"]["max_text_length"] + 2, 1], dtype="float32"
  103. ),
  104. InputSpec([None], dtype="int64"),
  105. InputSpec([None, 6], dtype="float64"),
  106. ]
  107. )
  108. elif algo == "SVTR":
  109. specs.append(
  110. [
  111. InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
  112. InputSpec([None], dtype="int64"),
  113. ]
  114. )
  115. elif algo == "LaTeXOCR":
  116. specs = [
  117. [
  118. InputSpec(shape=[None, 1, None, None], dtype="float32"),
  119. InputSpec(shape=[None, None], dtype="float32"),
  120. InputSpec(shape=[None, None], dtype="float32"),
  121. ]
  122. ]
  123. elif algo in ["UniMERNet", "PP-FormulaNet-S", "PP-FormulaNet-L"]:
  124. specs = [
  125. [
  126. InputSpec(
  127. [None] + config["Global"]["d2s_train_image_shape"], dtype="float32"
  128. ),
  129. InputSpec(shape=[None, None], dtype="float32"),
  130. InputSpec(shape=[None, None], dtype="float32"),
  131. ]
  132. ]
  133. model = to_static(model, input_spec=specs)
  134. logger.info("Successfully to apply @to_static with specs: {}".format(specs))
  135. return model