table_recognition_pipeline.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os.path as osp
  4. from typing import Any, Dict
  5. import cv2
  6. import numpy as np
  7. import PIL
  8. import torch
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Input, Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel
  14. from modelscope.pipelines.cv.ocr_utils.table_process import (
  15. bbox_decode, bbox_post_process, gbox_decode, gbox_post_process,
  16. get_affine_transform, group_bbox_by_gbox, nms)
  17. from modelscope.preprocessors import load_image
  18. from modelscope.preprocessors.image import LoadImage
  19. from modelscope.utils.constant import ModelFile, Tasks
  20. from modelscope.utils.logger import get_logger
  21. logger = get_logger()
  22. @PIPELINES.register_module(
  23. Tasks.table_recognition, module_name=Pipelines.table_recognition)
  24. class TableRecognitionPipeline(Pipeline):
  25. def __init__(self, model: str, **kwargs):
  26. """
  27. Args:
  28. model: model id on modelscope hub.
  29. """
  30. super().__init__(model=model, **kwargs)
  31. model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
  32. logger.info(f'loading model from {model_path}')
  33. self.K = 1000
  34. self.MK = 4000
  35. self.device = torch.device(
  36. 'cuda' if torch.cuda.is_available() else 'cpu')
  37. self.infer_model = TableRecModel().to(self.device)
  38. self.infer_model.eval()
  39. checkpoint = torch.load(
  40. model_path, map_location=self.device, weights_only=True)
  41. if 'state_dict' in checkpoint:
  42. self.infer_model.load_state_dict(checkpoint['state_dict'])
  43. else:
  44. self.infer_model.load_state_dict(checkpoint)
  45. def preprocess(self, input: Input) -> Dict[str, Any]:
  46. img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
  47. mean = np.array([0.408, 0.447, 0.470],
  48. dtype=np.float32).reshape(1, 1, 3)
  49. std = np.array([0.289, 0.274, 0.278],
  50. dtype=np.float32).reshape(1, 1, 3)
  51. height, width = img.shape[0:2]
  52. inp_height, inp_width = 1024, 1024
  53. c = np.array([width / 2., height / 2.], dtype=np.float32)
  54. s = max(height, width) * 1.0
  55. trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
  56. resized_image = cv2.resize(img, (width, height))
  57. inp_image = cv2.warpAffine(
  58. resized_image,
  59. trans_input, (inp_width, inp_height),
  60. flags=cv2.INTER_LINEAR)
  61. inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
  62. images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
  63. inp_width)
  64. images = torch.from_numpy(images).to(self.device)
  65. meta = {
  66. 'c': c,
  67. 's': s,
  68. 'input_height': inp_height,
  69. 'input_width': inp_width,
  70. 'out_height': inp_height // 4,
  71. 'out_width': inp_width // 4
  72. }
  73. result = {'img': images, 'meta': meta}
  74. return result
  75. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  76. pred = self.infer_model(input['img'])
  77. return {'results': pred, 'meta': input['meta']}
  78. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  79. output = inputs['results'][0]
  80. meta = inputs['meta']
  81. hm = output['hm'].sigmoid_()
  82. v2c = output['v2c']
  83. c2v = output['c2v']
  84. reg = output['reg']
  85. bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
  86. gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)
  87. bbox = bbox.detach().cpu().numpy()
  88. gbox = gbox.detach().cpu().numpy()
  89. bbox = nms(bbox, 0.3)
  90. bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
  91. [meta['s']], meta['out_height'],
  92. meta['out_width'])
  93. gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()],
  94. [meta['s']], meta['out_height'],
  95. meta['out_width'])
  96. bbox = group_bbox_by_gbox(bbox[0], gbox[0])
  97. res = []
  98. for box in bbox:
  99. if box[8] > 0.3:
  100. res.append(box[0:8])
  101. result = {OutputKeys.POLYGONS: np.array(res)}
  102. return result