| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- __all__ = ["build_head"]
- def build_head(config):
- # det head
- from .det_db_head import DBHead, PFHeadLocal
- from .det_east_head import EASTHead
- from .det_sast_head import SASTHead
- from .det_pse_head import PSEHead
- from .det_fce_head import FCEHead
- from .e2e_pg_head import PGHead
- from .det_ct_head import CT_Head
- # rec head
- from .rec_ctc_head import CTCHead
- from .rec_att_head import AttentionHead
- from .rec_srn_head import SRNHead
- from .rec_nrtr_head import Transformer
- from .rec_sar_head import SARHead
- from .rec_aster_head import AsterHead
- from .rec_pren_head import PRENHead
- from .rec_multi_head import MultiHead
- from .rec_spin_att_head import SPINAttentionHead
- from .rec_abinet_head import ABINetHead
- from .rec_robustscanner_head import RobustScannerHead
- from .rec_visionlan_head import VLHead
- from .rec_rfl_head import RFLHead
- from .rec_can_head import CANHead
- from .rec_latexocr_head import LaTeXOCRHead
- from .rec_satrn_head import SATRNHead
- from .rec_parseq_head import ParseQHead
- from .rec_cppd_head import CPPDHead
- from .rec_unimernet_head import UniMERNetHead
- from .rec_ppformulanet_head import PPFormulaNet_Head
- # cls head
- from .cls_head import ClsHead
- # kie head
- from .kie_sdmgr_head import SDMGRHead
- from .table_att_head import TableAttentionHead, SLAHead
- from .table_master_head import TableMasterHead
- support_dict = [
- "DBHead",
- "PSEHead",
- "FCEHead",
- "EASTHead",
- "SASTHead",
- "CTCHead",
- "ClsHead",
- "AttentionHead",
- "SRNHead",
- "PGHead",
- "Transformer",
- "TableAttentionHead",
- "SARHead",
- "AsterHead",
- "SDMGRHead",
- "PRENHead",
- "MultiHead",
- "ABINetHead",
- "TableMasterHead",
- "SPINAttentionHead",
- "VLHead",
- "SLAHead",
- "RobustScannerHead",
- "CT_Head",
- "RFLHead",
- "DRRGHead",
- "CANHead",
- "LaTeXOCRHead",
- "SATRNHead",
- "PFHeadLocal",
- "ParseQHead",
- "CPPDHead",
- "UniMERNetHead",
- "PPFormulaNet_Head",
- ]
- if config["name"] == "DRRGHead":
- from .det_drrg_head import DRRGHead
- support_dict.append("DRRGHead")
- # table head
- module_name = config.pop("name")
- assert module_name in support_dict, Exception(
- "head only support {}".format(support_dict)
- )
- module_class = eval(module_name)(**config)
- return module_class
|