# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.ocr_detection import OCRDetection
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger
from .ocr_utils import cal_width, nms_python, rboxes_to_polygons

logger = get_logger()

# constant
RBOX_DIM = 5
OFFSET_DIM = 6
WORD_POLYGON_DIM = 8
OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

TF_NODE_THRESHOLD = 0.4
TF_LINK_THRESHOLD = 0.6

# Side length of the square input fed to the SegLink++ graph. The placeholder
# shape and the resize step in `preprocess` must agree on this value.
_SEGLINK_INPUT_SIZE = 1024


@PIPELINES.register_module(
    Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
class OCRDetectionPipeline(Pipeline):
    """ OCR Detection Pipeline.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline

    >>> ocr_detection = pipeline('ocr-detection', model='damo/cv_resnet18_ocr-detection-line-level_damo')
    >>> result = ocr_detection('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/ocr_detection.jpg')

        {'polygons': array([[220,  14, 780,  14, 780,  64, 220,  64],
       [196, 369, 604, 370, 604, 425, 196, 425],
       [ 21, 730, 425, 731, 425, 787,  21, 786],
       [421, 731, 782, 731, 782, 789, 421, 789],
       [  0, 121, 109,   0, 147,  35,  26, 159],
       [697, 160, 773, 160, 773, 197, 697, 198],
       [547, 205, 623, 205, 623, 244, 547, 244],
       [548, 161, 623, 161, 623, 199, 547, 199],
       [698, 206, 772, 206, 772, 244, 698, 244]])}
    ```
    note:
    model = damo/cv_resnet18_ocr-detection-line-level_damo, for general text line detection, based on SegLink++.
    model = damo/cv_resnet18_ocr-detection-word-level_damo, for general text word detection, based on SegLink++.
    model = damo/cv_resnet50_ocr-detection-vlpt, for the Total-Text dataset, based on VLPT_pretrained DBNet.
    model = damo/cv_resnet18_ocr-detection-db-line-level_damo, for general text line detection, based on DBNet.
    """

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create an OCR detection pipeline for prediction.

        The model configuration decides which backend is set up: a config
        whose `model.model_type` is 'DBNet' uses the already-loaded model
        object, anything else (including a missing `model_type`) builds the
        TensorFlow SegLink++ graph and restores its checkpoint.

        Args:
            model: model id on modelscope hub.
        """
        assert isinstance(model, str), 'model must be a single str'
        super().__init__(model=model, **kwargs)
        logger.info(f'loading model from dir {model}')
        cfgs = Config.from_file(os.path.join(model, ModelFile.CONFIGURATION))
        if hasattr(cfgs, 'model') and hasattr(cfgs.model, 'model_type'):
            self.model_type = cfgs.model.model_type
        else:
            self.model_type = 'SegLink++'

        if self.model_type == 'DBNet':
            self.ocr_detector = self.model.to(self.device)
            self.ocr_detector.eval()
            logger.info('loading model done')
        else:
            # for model seglink++
            import tensorflow as tf

            # Compare the numeric major version: a plain string comparison
            # is lexicographic and would treat e.g. '10.0' as older than '2.0'.
            if int(tf.__version__.split('.')[0]) >= 2:
                tf = tf.compat.v1
            tf.compat.v1.disable_eager_execution()

            # NOTE(review): the flags are registered on every construction;
            # confirm that building a second SegLink++ pipeline in the same
            # process does not trip a duplicate-flag error.
            tf.app.flags.DEFINE_float('node_threshold', TF_NODE_THRESHOLD,
                                      'Confidence threshold for nodes')
            tf.app.flags.DEFINE_float('link_threshold', TF_LINK_THRESHOLD,
                                      'Confidence threshold for links')
            tf.reset_default_graph()
            model_path = osp.join(
                osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
                'checkpoint-80000')
            self._graph = tf.get_default_graph()
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self._session = tf.Session(config=config)

            with self._graph.as_default():
                with device_placement(self.framework, self.device_name):
                    self.input_images = tf.placeholder(
                        tf.float32,
                        shape=[
                            1, _SEGLINK_INPUT_SIZE, _SEGLINK_INPUT_SIZE, 3
                        ],
                        name='input_images')
                    # Fetches evaluated by `forward`; fixed once the graph
                    # has been built here.
                    self.output = {}

                    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                        global_step = tf.get_variable(
                            'global_step', [],
                            initializer=tf.constant_initializer(0),
                            dtype=tf.int64,
                            trainable=False)
                        variable_averages = tf.train.ExponentialMovingAverage(
                            0.997, global_step)

                        from .ocr_utils import (SegLinkDetector,
                                                combine_segments_python,
                                                decode_segments_links_python)

                        # detector
                        detector = SegLinkDetector()
                        all_maps = detector.build_model(
                            self.input_images, is_training=False)

                        # decode local predictions
                        all_nodes, all_links, all_reg = [], [], []
                        for maps in all_maps:
                            cls_maps, lnk_maps, reg_maps = (maps[0], maps[1],
                                                            maps[2])
                            reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)

                            cls_prob = tf.nn.softmax(
                                tf.reshape(cls_maps, [-1, 2]))

                            # The four link channels are soft-maxed as two
                            # independent pairs and concatenated again.
                            lnk_flat = tf.reshape(lnk_maps, [-1, 4])
                            lnk_prob_pos = tf.nn.softmax(lnk_flat[:, :2])
                            lnk_prob_mut = tf.nn.softmax(lnk_flat[:, 2:])
                            lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut],
                                                 axis=1)

                            all_nodes.append(cls_prob)
                            all_links.append(lnk_prob)
                            all_reg.append(reg_maps)

                        # decode segments and links
                        image_size = tf.shape(self.input_images)[1:3]
                        (segments, group_indices, segment_counts,
                         _) = decode_segments_links_python(
                             image_size,
                             all_nodes,
                             all_links,
                             all_reg,
                             anchor_sizes=list(detector.anchor_sizes))

                        # combine segments
                        (combined_rboxes,
                         combined_counts) = combine_segments_python(
                             segments, group_indices, segment_counts)
                        self.output['combined_rboxes'] = combined_rboxes
                        self.output['combined_counts'] = combined_counts

                    with self._session.as_default() as sess:
                        logger.info(f'loading model from {model_path}')
                        # load model
                        model_loader = tf.train.Saver(
                            variable_averages.variables_to_restore())
                        model_loader.restore(sess, model_path)

    def __call__(self, input, **kwargs):
        """
        Detect text instance in the text image.

        Args:
            input (`Image`):
                The pipeline handles three types of images:

                - A string containing an HTTP link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL or opencv directly

                The pipeline currently supports single image input.

        Return:
            An array of contour polygons of detected N text instances in image,
            every row is [x1, y1, x2, y2, x3, y3, x4, y4, ...].
        """
        return super().__call__(input, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Turn the raw input into the dict consumed by `forward`.

        For DBNet the configured preprocessor is used unchanged. For
        SegLink++ the image is zero-padded to a square, resized to the fixed
        network input size, converted with `cv2.COLOR_RGB2BGR` and shifted by
        per-channel constants.

        Returns:
            For SegLink++: a dict with the batched image under 'img' plus
            'orig_size' and 'resize_size' (both `[height, width]` int32
            arrays) which `postprocess` needs to map results back onto the
            original image.
        """
        if self.model_type == 'DBNet':
            return self.preprocessor(input)

        # for model seglink++
        img = LoadImage.convert_to_ndarray(input)
        h, w = img.shape[:2]
        side = max(h, w)

        # Pad bottom/right with zeros so the aspect ratio survives the resize.
        img_pad = np.zeros((side, side, 3), dtype=np.float32)
        img_pad[:h, :w, :] = img

        resize_size = _SEGLINK_INPUT_SIZE
        img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size))
        img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR)
        img_pad_resize = img_pad_resize - np.array(
            [123.68, 116.78, 103.94], dtype=np.float32)

        # The sizes travel with the data as plain arrays. Creating TF ops
        # here would add new nodes to the graph on every call (unbounded
        # growth) and would mutate state shared between calls.
        return {
            'img': np.expand_dims(img_pad_resize, axis=0),
            'orig_size': np.array([side, side], dtype=np.int32),
            'resize_size': np.array([resize_size, resize_size],
                                    dtype=np.int32),
        }

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the detector on the preprocessed input.

        Returns:
            For DBNet, whatever the model returns. For SegLink++, the
            evaluated graph outputs ('combined_rboxes', 'combined_counts')
            together with the size entries passed in by `preprocess`.
        """
        if self.model_type == 'DBNet':
            return self.ocr_detector(input)

        with self._graph.as_default():
            with self._session.as_default():
                feed_dict = {self.input_images: input['img']}
                sess_outputs = self._session.run(
                    self.output, feed_dict=feed_dict)

        # Hand the size information through to `postprocess`.
        for key in ('orig_size', 'resize_size'):
            if key in input:
                sess_outputs[key] = input[key]
        return sess_outputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert raw model outputs into polygons on the original image.

        Returns:
            A dict with the detected polygons under `OutputKeys.POLYGONS`.

        Raises:
            Exception: for SegLink++ when the count check below fails.
        """
        if self.model_type == 'DBNet':
            return {OutputKeys.POLYGONS: inputs['det_polygons']}

        rboxes = inputs['combined_rboxes'][0]
        count = inputs['combined_counts'][0]
        # NOTE(review): besides the empty case this also rejects results
        # where `count` is smaller than the number of returned rows, which
        # makes the slice below a no-op - confirm this is intended.
        if count == 0 or count < rboxes.shape[0]:
            raise Exception('modelscope error: No text detected')
        rboxes = rboxes[:count, :]

        # convert rboxes to polygons and find its coordinates on the original image
        orig_h, orig_w = inputs['orig_size']
        resize_h, resize_w = inputs['resize_size']
        polygons = rboxes_to_polygons(rboxes)
        scale_y = float(orig_h) / float(resize_h)
        scale_x = float(orig_w) / float(resize_w)

        # confine polygons inside image
        # (even columns are scaled by x, odd columns by y)
        polygons[:, ::2] = np.maximum(
            0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1))
        polygons[:, 1::2] = np.maximum(
            0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1))
        polygons = np.round(polygons).astype(np.int32)

        # nms
        # each polygon gets the value of `cal_width` appended before NMS;
        # only the first eight values of the survivors are kept
        dt_n9 = [o + [cal_width(o)] for o in polygons.tolist()]
        dt_nms = nms_python(dt_n9)
        dt_polygons = np.array([o[:8] for o in dt_nms])

        return {OutputKeys.POLYGONS: dt_polygons}