# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from typing import Any, Dict import cv2 import numpy as np import PIL import torch from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel from modelscope.pipelines.cv.ocr_utils.table_process import ( bbox_decode, bbox_post_process, gbox_decode, gbox_post_process, get_affine_transform, group_bbox_by_gbox, nms) from modelscope.preprocessors import load_image from modelscope.preprocessors.image import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger logger = get_logger() @PIPELINES.register_module( Tasks.table_recognition, module_name=Pipelines.table_recognition) class TableRecognitionPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ Args: model: model id on modelscope hub. """ super().__init__(model=model, **kwargs) model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) logger.info(f'loading model from {model_path}') self.K = 1000 self.MK = 4000 self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.infer_model = TableRecModel().to(self.device) self.infer_model.eval() checkpoint = torch.load( model_path, map_location=self.device, weights_only=True) if 'state_dict' in checkpoint: self.infer_model.load_state_dict(checkpoint['state_dict']) else: self.infer_model.load_state_dict(checkpoint) def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input)[:, :, ::-1] mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3) height, width = img.shape[0:2] inp_height, inp_width = 1024, 1024 c = np.array([width / 2., height / 2.], dtype=np.float32) s = max(height, width) * 1.0 trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height]) resized_image = cv2.resize(img, (width, height)) inp_image = cv2.warpAffine( resized_image, trans_input, (inp_width, inp_height), flags=cv2.INTER_LINEAR) inp_image = ((inp_image / 255. - mean) / std).astype(np.float32) images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width) images = torch.from_numpy(images).to(self.device) meta = { 'c': c, 's': s, 'input_height': inp_height, 'input_width': inp_width, 'out_height': inp_height // 4, 'out_width': inp_width // 4 } result = {'img': images, 'meta': meta} return result def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: pred = self.infer_model(input['img']) return {'results': pred, 'meta': input['meta']} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: output = inputs['results'][0] meta = inputs['meta'] hm = output['hm'].sigmoid_() v2c = output['v2c'] c2v = output['c2v'] reg = output['reg'] bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K) gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK) bbox = bbox.detach().cpu().numpy() gbox = gbox.detach().cpu().numpy() bbox = nms(bbox, 0.3) bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()], [meta['s']], meta['out_height'], meta['out_width']) gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()], [meta['s']], meta['out_height'], meta['out_width']) bbox = group_bbox_by_gbox(bbox[0], gbox[0]) res = [] for box in bbox: if box[8] > 0.3: res.append(box[0:8]) result = {OutputKeys.POLYGONS: np.array(res)} return result