model.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from modelscope.metainfo import Models
  8. from modelscope.models.base.base_torch_model import TorchModel
  9. from modelscope.models.builder import MODELS
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModelFile, Tasks
  12. from modelscope.utils.logger import get_logger
  13. from .modules.dbnet import DBModel, DBNasModel, VLPTModel
  14. from .utils import boxes_from_bitmap, polygons_from_bitmap
  # Module-level logger from modelscope.utils.logger.get_logger(); not referenced by the code visible in this view.
  15. LOGGER = get_logger()
  @MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection)
  class OCRDetection(TorchModel):
      """OCR text-detection model wrapping a detector from `.modules.dbnet`.

      The detector network is selected by the `model.backbone` entry of the
      configuration file found in the model directory.
      """

      def __init__(self, model_dir: str, **kwargs):
          """Initialize the OCR detection model from the `model_dir` path.

          Args:
              model_dir (str): the model path. The configuration file
                  (`ModelFile.CONFIGURATION`) and the weights file
                  (`ModelFile.TORCH_MODEL_FILE`) are read from it.

          Raises:
              TypeError: if the configured backbone is not supported.
          """
          super().__init__(model_dir, **kwargs)
          model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
          cfgs = Config.from_file(
              os.path.join(model_dir, ModelFile.CONFIGURATION))
          self.thresh = cfgs.model.inference_kwargs.thresh
          self.return_polygon = cfgs.model.inference_kwargs.return_polygon
          self.backbone = cfgs.model.backbone
          self.detector = None
          # Read by `postprocess`: when True the raw prediction is returned.
          self.onnx_export = False

          # Single source of truth for the supported backbones, so the error
          # message below cannot drift from what is actually accepted.
          detector_classes = {
              'resnet50': VLPTModel,
              'resnet18': DBModel,
              'proxylessnas': DBNasModel,
          }
          detector_cls = detector_classes.get(self.backbone)
          if detector_cls is None:
              supported = ', '.join(detector_classes)
              raise TypeError(
                  f'detector backbone should be one of {supported}, '
                  f'but got {self.backbone}')
          self.detector = detector_cls()

          # NOTE(review): the joined path is non-empty unless both components
          # are empty, so this guard is effectively always true. It is kept
          # as-is to preserve the original behaviour.
          if model_path != '':
              # Load on CPU; strict=False tolerates missing/unexpected keys.
              self.detector.load_state_dict(
                  torch.load(model_path, map_location='cpu', weights_only=True),
                  strict=False)

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run the detector on an image.

          Args:
              input: either a dict with the keys `img` (the image passed to
                  the detector) and `org_shape` (the original image shape),
                  or the image itself. The second form is used for ONNX
                  conversion, and `org_shape` is then fixed to `[800, 800]`.
                  The original docstring describes `img` as a
                  `torch.Tensor` of shape [3, H, W] -- TODO confirm.

          Return:
              A dict with:
              results: the detector output (described by the original
                  docstring as a bitmap tensor of shape [1, H, W] -- TODO
                  confirm).
              org_shape (`List`): image original shape,
                  value is [height, width].
          """
          # isinstance (rather than an exact type check) so that dict
          # subclasses are treated as dict input instead of as an image.
          if isinstance(input, dict):
              img = input['img']
              org_shape = input['org_shape']
          else:
              # for onnx convert
              img = input
              org_shape = [800, 800]
          pred = self.detector(img)
          return {'results': pred, 'org_shape': org_shape}

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Turn the output of `forward` into detected boxes or polygons.

          Args:
              inputs: a dict with the keys `results` (only its first
                  element is used) and `org_shape` (unpacked as
                  `height, width`).

          Return:
              A dict `{'det_polygons': np.ndarray}` built from the boxes
              returned by `polygons_from_bitmap` (when
              `self.return_polygon` is truthy) or `boxes_from_bitmap`.
              When `self.onnx_export` is truthy, the first element of
              `results` is returned directly instead of a dict.
          """
          pred = inputs['results'][0]
          if self.onnx_export:
              return pred
          height, width = inputs['org_shape']
          # Binarize the prediction with the configured threshold.
          segmentation = pred > self.thresh
          extract = (
              polygons_from_bitmap
              if self.return_polygon else boxes_from_bitmap)
          # The scores returned by the helper are not part of the result.
          boxes, _scores = extract(pred, segmentation, width, height)
          return {'det_polygons': np.array(boxes)}