# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. __all__ = ["build_backbone"]
  15. def build_backbone(config, model_type):
  16. if model_type == "det" or model_type == "table":
  17. from .det_mobilenet_v3 import MobileNetV3
  18. from .det_resnet import ResNet
  19. from .det_resnet_vd import ResNet_vd
  20. from .det_resnet_vd_sast import ResNet_SAST
  21. from .det_pp_lcnet import PPLCNet
  22. from .rec_lcnetv3 import PPLCNetV3
  23. from .rec_hgnet import PPHGNet_small
  24. from .rec_vit import ViT
  25. from .det_pp_lcnet_v2 import PPLCNetV2_base
  26. from .rec_repvit import RepSVTR_det
  27. from .rec_vary_vit import Vary_VIT_B
  28. from .rec_pphgnetv2 import PPHGNetV2_B4
  29. support_dict = [
  30. "MobileNetV3",
  31. "ResNet",
  32. "ResNet_vd",
  33. "ResNet_SAST",
  34. "PPLCNet",
  35. "PPLCNetV3",
  36. "PPHGNet_small",
  37. "PPLCNetV2_base",
  38. "RepSVTR_det",
  39. "Vary_VIT_B",
  40. "PPHGNetV2_B4",
  41. ]
  42. if model_type == "table":
  43. from .table_master_resnet import TableResNetExtra
  44. support_dict.append("TableResNetExtra")
  45. elif model_type == "rec" or model_type == "cls":
  46. from .rec_mobilenet_v3 import MobileNetV3
  47. from .rec_resnet_vd import ResNet
  48. from .rec_resnet_fpn import ResNetFPN
  49. from .rec_mv1_enhance import MobileNetV1Enhance
  50. from .rec_nrtr_mtb import MTB
  51. from .rec_resnet_31 import ResNet31
  52. from .rec_resnet_32 import ResNet32
  53. from .rec_resnet_45 import ResNet45
  54. from .rec_resnet_aster import ResNet_ASTER
  55. from .rec_micronet import MicroNet
  56. from .rec_efficientb3_pren import EfficientNetb3_PREN
  57. from .rec_svtrnet import SVTRNet
  58. from .rec_vitstr import ViTSTR
  59. from .rec_resnet_rfl import ResNetRFL
  60. from .rec_densenet import DenseNet
  61. from .rec_resnetv2 import ResNetV2
  62. from .rec_hybridvit import HybridTransformer
  63. from .rec_donut_swin import DonutSwinModel
  64. from .rec_shallow_cnn import ShallowCNN
  65. from .rec_lcnetv3 import PPLCNetV3
  66. from .rec_hgnet import PPHGNet_small
  67. from .rec_vit_parseq import ViTParseQ
  68. from .rec_repvit import RepSVTR
  69. from .rec_svtrv2 import SVTRv2
  70. from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
  71. from .rec_pphgnetv2 import (
  72. PPHGNetV2_B4,
  73. PPHGNetV2_B4_Formula,
  74. PPHGNetV2_B6_Formula,
  75. )
  76. support_dict = [
  77. "MobileNetV1Enhance",
  78. "MobileNetV3",
  79. "ResNet",
  80. "ResNetFPN",
  81. "MTB",
  82. "ResNet31",
  83. "ResNet45",
  84. "ResNet_ASTER",
  85. "MicroNet",
  86. "EfficientNetb3_PREN",
  87. "SVTRNet",
  88. "ViTSTR",
  89. "ResNet32",
  90. "ResNetRFL",
  91. "DenseNet",
  92. "ShallowCNN",
  93. "PPLCNetV3",
  94. "PPHGNet_small",
  95. "ViTParseQ",
  96. "ViT",
  97. "RepSVTR",
  98. "SVTRv2",
  99. "ResNetV2",
  100. "HybridTransformer",
  101. "DonutSwinModel",
  102. "Vary_VIT_B",
  103. "PPHGNetV2_B4",
  104. "PPHGNetV2_B4_Formula",
  105. "PPHGNetV2_B6_Formula",
  106. "Vary_VIT_B_Formula",
  107. ]
  108. elif model_type == "e2e":
  109. from .e2e_resnet_vd_pg import ResNet
  110. support_dict = ["ResNet"]
  111. elif model_type == "kie":
  112. from .kie_unet_sdmgr import Kie_backbone
  113. from .vqa_layoutlm import (
  114. LayoutLMForSer,
  115. LayoutLMv2ForSer,
  116. LayoutLMv2ForRe,
  117. LayoutXLMForSer,
  118. LayoutXLMForRe,
  119. )
  120. support_dict = [
  121. "Kie_backbone",
  122. "LayoutLMForSer",
  123. "LayoutLMv2ForSer",
  124. "LayoutLMv2ForRe",
  125. "LayoutXLMForSer",
  126. "LayoutXLMForRe",
  127. ]
  128. elif model_type == "table":
  129. from .table_resnet_vd import ResNet
  130. from .table_mobilenet_v3 import MobileNetV3
  131. from .rec_vary_vit import Vary_VIT_B
  132. support_dict = ["ResNet", "MobileNetV3", "Vary_VIT_B"]
  133. else:
  134. raise NotImplementedError
  135. module_name = config.pop("name")
  136. assert module_name in support_dict, Exception(
  137. "when model typs is {}, backbone only support {}".format(
  138. model_type, support_dict
  139. )
  140. )
  141. module_class = eval(module_name)(**config)
  142. return module_class