| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict
- import numpy as np
- import torch
- import torch.nn.functional as F
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.config import Config
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- from .modules.dbnet import DBModel, DBNasModel, VLPTModel
- from .utils import boxes_from_bitmap, polygons_from_bitmap
- LOGGER = get_logger()
@MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetection(TorchModel):
    """OCR text-detection model registered for ``Tasks.ocr_detection``.

    The concrete detector network is chosen from the ``model.backbone``
    entry of the configuration file found in the model directory:
    ``resnet50`` -> ``VLPTModel``, ``resnet18`` -> ``DBModel``,
    ``proxylessnas`` -> ``DBNasModel``.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the OCR detection model from the `model_dir` path.

        Args:
            model_dir (str): the model path. It must contain the
                configuration file (``ModelFile.CONFIGURATION``) and the
                weights file (``ModelFile.TORCH_MODEL_FILE``).

        Raises:
            TypeError: if the configured backbone is not one of
                ``resnet18``, ``resnet50`` or ``proxylessnas``.
        """
        super().__init__(model_dir, **kwargs)
        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        cfgs = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))

        # Post-processing options read from the configuration.
        self.thresh = cfgs.model.inference_kwargs.thresh
        self.return_polygon = cfgs.model.inference_kwargs.return_polygon
        self.backbone = cfgs.model.backbone
        self.detector = None
        # When True, `postprocess` returns the raw prediction unchanged.
        self.onnx_export = False

        if self.backbone == 'resnet50':
            self.detector = VLPTModel()
        elif self.backbone == 'resnet18':
            self.detector = DBModel()
        elif self.backbone == 'proxylessnas':
            self.detector = DBNasModel()
        else:
            # The message lists every backbone accepted above.
            raise TypeError(
                'detector backbone should be one of resnet18, resnet50, '
                f'proxylessnas, but got {cfgs.model.backbone}')

        # Guard kept from the original code; a joined path is only empty
        # when both of its components are empty.
        if model_path != '':
            # strict=False: keys that do not match the detector are ignored.
            state_dict = torch.load(
                model_path, map_location='cpu', weights_only=True)
            self.detector.load_state_dict(state_dict, strict=False)

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the detector on an image.

        Args:
            input: either a dict with the keys ``'img'`` (image tensor; the
                original documentation gives the shape of each tensor as
                [3, H, W]) and ``'org_shape'`` ([height, width] of the
                original image), or -- for ONNX conversion -- the bare
                image tensor itself.

        Return:
            A dict with:
            results (`torch.Tensor`): detector output (documented
                originally as a bitmap tensor of shape [1, H, W]).
            org_shape (`List`): image original shape,
                value is [height, width]. Fixed to [800, 800] when a bare
                tensor was passed.
        """
        if not isinstance(input, dict):
            # for onnx convert: wrap the bare tensor, assuming 800x800
            input = {'img': input, 'org_shape': [800, 800]}
        pred = self.detector(input['img'])
        return {'results': pred, 'org_shape': input['org_shape']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the detector output into text boxes or polygons.

        Args:
            inputs: the dict produced by `forward`, with the keys
                ``'results'`` and ``'org_shape'``.

        Return:
            ``{'det_polygons': np.ndarray}`` holding the detected regions.
            When ``self.onnx_export`` is set, the first element of
            ``inputs['results']`` is returned as is instead of a dict.
        """
        pred = inputs['results'][0]
        if self.onnx_export:
            return pred

        height, width = inputs['org_shape']
        segmentation = pred > self.thresh
        # Both helpers share one signature and return (boxes, scores);
        # the scores are not used here.
        if self.return_polygon:
            extract = polygons_from_bitmap
        else:
            extract = boxes_from_bitmap
        boxes, _ = extract(pred, segmentation, width, height)
        return {'det_polygons': np.array(boxes)}
|