# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Any, Dict import numpy as np import torch import torch.nn.functional as F from modelscope.metainfo import Models from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from .modules.dbnet import DBModel, DBNasModel, VLPTModel from .utils import boxes_from_bitmap, polygons_from_bitmap LOGGER = get_logger() @MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection) class OCRDetection(TorchModel): def __init__(self, model_dir: str, **kwargs): """initialize the ocr recognition model from the `model_dir` path. Args: model_dir (str): the model path. """ super().__init__(model_dir, **kwargs) model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) cfgs = Config.from_file( os.path.join(model_dir, ModelFile.CONFIGURATION)) self.thresh = cfgs.model.inference_kwargs.thresh self.return_polygon = cfgs.model.inference_kwargs.return_polygon self.backbone = cfgs.model.backbone self.detector = None self.onnx_export = False if self.backbone == 'resnet50': self.detector = VLPTModel() elif self.backbone == 'resnet18': self.detector = DBModel() elif self.backbone == 'proxylessnas': self.detector = DBNasModel() else: raise TypeError( f'detector backbone should be either resnet18, resnet50, but got {cfgs.model.backbone}' ) if model_path != '': self.detector.load_state_dict( torch.load(model_path, map_location='cpu', weights_only=True), strict=False) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: """ Args: img (`torch.Tensor`): image tensor, shape of each tensor is [3, H, W]. Return: results (`torch.Tensor`): bitmap tensor, shape of each tensor is [1, H, W]. org_shape (`List`): image original shape, value is [height, width]. """ if type(input) is dict: pred = self.detector(input['img']) else: # for onnx convert input = {'img': input, 'org_shape': [800, 800]} pred = self.detector(input['img']) return {'results': pred, 'org_shape': input['org_shape']} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: pred = inputs['results'][0] if self.onnx_export: return pred height, width = inputs['org_shape'] segmentation = pred > self.thresh if self.return_polygon: boxes, scores = polygons_from_bitmap(pred, segmentation, width, height) else: boxes, scores = boxes_from_bitmap(pred, segmentation, width, height) result = {'det_polygons': np.array(boxes)} return result