ocr_detection_pipeline.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os
  4. import os.path as osp
  5. from typing import Any, Dict
  6. import cv2
  7. import numpy as np
  8. import torch
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.cv.ocr_detection import OCRDetection
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.preprocessors import LoadImage
  15. from modelscope.utils.config import Config
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. from modelscope.utils.device import device_placement
  18. from modelscope.utils.logger import get_logger
  19. from .ocr_utils import cal_width, nms_python, rboxes_to_polygons
  20. logger = get_logger()
  21. # constant
  22. RBOX_DIM = 5
  23. OFFSET_DIM = 6
  24. WORD_POLYGON_DIM = 8
  25. OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
  26. TF_NODE_THRESHOLD = 0.4
  27. TF_LINK_THRESHOLD = 0.6
  28. @PIPELINES.register_module(
  29. Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
  30. class OCRDetectionPipeline(Pipeline):
  31. """ OCR Detection Pipeline.
  32. Example:
  33. ```python
  34. >>> from modelscope.pipelines import pipeline
  35. >>> ocr_detection = pipeline('ocr-detection', model='damo/cv_resnet18_ocr-detection-line-level_damo')
  36. >>> result = ocr_detection('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/ocr_detection.jpg')
  37. {'polygons': array([[220, 14, 780, 14, 780, 64, 220, 64],
  38. [196, 369, 604, 370, 604, 425, 196, 425],
  39. [ 21, 730, 425, 731, 425, 787, 21, 786],
  40. [421, 731, 782, 731, 782, 789, 421, 789],
  41. [ 0, 121, 109, 0, 147, 35, 26, 159],
  42. [697, 160, 773, 160, 773, 197, 697, 198],
  43. [547, 205, 623, 205, 623, 244, 547, 244],
  44. [548, 161, 623, 161, 623, 199, 547, 199],
  45. [698, 206, 772, 206, 772, 244, 698, 244]])}
  46. ```
  47. note:
  48. model = damo/cv_resnet18_ocr-detection-line-level_damo, for general text line detection, based on SegLink++.
  49. model = damo/cv_resnet18_ocr-detection-word-level_damo, for general text word detection, based on SegLink++.
  50. model = damo/cv_resnet50_ocr-detection-vlpt, for toaltext dataset, based on VLPT_pretrained DBNet.
  51. model = damo/cv_resnet18_ocr-detection-db-line-level_damo, for general text line detection, based on DBNet.
  52. """
  53. def __init__(self, model: str, **kwargs):
  54. """
  55. use `model` to create a OCR detection pipeline for prediction
  56. Args:
  57. model: model id on modelscope hub.
  58. """
  59. assert isinstance(model, str), 'model must be a single str'
  60. super().__init__(model=model, **kwargs)
  61. logger.info(f'loading model from dir {model}')
  62. cfgs = Config.from_file(os.path.join(model, ModelFile.CONFIGURATION))
  63. if hasattr(cfgs, 'model') and hasattr(cfgs.model, 'model_type'):
  64. self.model_type = cfgs.model.model_type
  65. else:
  66. self.model_type = 'SegLink++'
  67. if self.model_type == 'DBNet':
  68. self.ocr_detector = self.model.to(self.device)
  69. self.ocr_detector.eval()
  70. logger.info('loading model done')
  71. else:
  72. # for model seglink++
  73. import tensorflow as tf
  74. if tf.__version__ >= '2.0':
  75. tf = tf.compat.v1
  76. tf.compat.v1.disable_eager_execution()
  77. tf.app.flags.DEFINE_float('node_threshold', TF_NODE_THRESHOLD,
  78. 'Confidence threshold for nodes')
  79. tf.app.flags.DEFINE_float('link_threshold', TF_LINK_THRESHOLD,
  80. 'Confidence threshold for links')
  81. tf.reset_default_graph()
  82. model_path = osp.join(
  83. osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
  84. 'checkpoint-80000')
  85. self._graph = tf.get_default_graph()
  86. config = tf.ConfigProto(allow_soft_placement=True)
  87. config.gpu_options.allow_growth = True
  88. self._session = tf.Session(config=config)
  89. with self._graph.as_default():
  90. with device_placement(self.framework, self.device_name):
  91. self.input_images = tf.placeholder(
  92. tf.float32,
  93. shape=[1, 1024, 1024, 3],
  94. name='input_images')
  95. self.output = {}
  96. with tf.variable_scope('', reuse=tf.AUTO_REUSE):
  97. global_step = tf.get_variable(
  98. 'global_step', [],
  99. initializer=tf.constant_initializer(0),
  100. dtype=tf.int64,
  101. trainable=False)
  102. variable_averages = tf.train.ExponentialMovingAverage(
  103. 0.997, global_step)
  104. from .ocr_utils import SegLinkDetector, combine_segments_python, decode_segments_links_python
  105. # detector
  106. detector = SegLinkDetector()
  107. all_maps = detector.build_model(
  108. self.input_images, is_training=False)
  109. # decode local predictions
  110. all_nodes, all_links, all_reg = [], [], []
  111. for i, maps in enumerate(all_maps):
  112. cls_maps, lnk_maps, reg_maps = maps[0], maps[
  113. 1], maps[2]
  114. reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
  115. cls_prob = tf.nn.softmax(
  116. tf.reshape(cls_maps, [-1, 2]))
  117. lnk_prob_pos = tf.nn.softmax(
  118. tf.reshape(lnk_maps, [-1, 4])[:, :2])
  119. lnk_prob_mut = tf.nn.softmax(
  120. tf.reshape(lnk_maps, [-1, 4])[:, 2:])
  121. lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut],
  122. axis=1)
  123. all_nodes.append(cls_prob)
  124. all_links.append(lnk_prob)
  125. all_reg.append(reg_maps)
  126. # decode segments and links
  127. image_size = tf.shape(self.input_images)[1:3]
  128. segments, group_indices, segment_counts, _ = decode_segments_links_python(
  129. image_size,
  130. all_nodes,
  131. all_links,
  132. all_reg,
  133. anchor_sizes=list(detector.anchor_sizes))
  134. # combine segments
  135. combined_rboxes, combined_counts = combine_segments_python(
  136. segments, group_indices, segment_counts)
  137. self.output['combined_rboxes'] = combined_rboxes
  138. self.output['combined_counts'] = combined_counts
  139. with self._session.as_default() as sess:
  140. logger.info(f'loading model from {model_path}')
  141. # load model
  142. model_loader = tf.train.Saver(
  143. variable_averages.variables_to_restore())
  144. model_loader.restore(sess, model_path)
  145. def __call__(self, input, **kwargs):
  146. """
  147. Detect text instance in the text image.
  148. Args:
  149. input (`Image`):
  150. The pipeline handles three types of images:
  151. - A string containing an HTTP link pointing to an image
  152. - A string containing a local path to an image
  153. - An image loaded in PIL or opencv directly
  154. The pipeline currently supports single image input.
  155. Return:
  156. An array of contour polygons of detected N text instances in image,
  157. every row is [x1, y1, x2, y2, x3, y3, x4, y4, ...].
  158. """
  159. return super().__call__(input, **kwargs)
  160. def preprocess(self, input: Input) -> Dict[str, Any]:
  161. if self.model_type == 'DBNet':
  162. result = self.preprocessor(input)
  163. return result
  164. else:
  165. # for model seglink++
  166. img = LoadImage.convert_to_ndarray(input)
  167. h, w, c = img.shape
  168. img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
  169. img_pad[:h, :w, :] = img
  170. resize_size = 1024
  171. img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size))
  172. img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR)
  173. img_pad_resize = img_pad_resize - np.array(
  174. [123.68, 116.78, 103.94], dtype=np.float32)
  175. import tensorflow as tf
  176. with self._graph.as_default():
  177. resize_size = tf.stack([resize_size, resize_size])
  178. orig_size = tf.stack([max(h, w), max(h, w)])
  179. self.output['orig_size'] = orig_size
  180. self.output['resize_size'] = resize_size
  181. result = {'img': np.expand_dims(img_pad_resize, axis=0)}
  182. return result
  183. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  184. if self.model_type == 'DBNet':
  185. outputs = self.ocr_detector(input)
  186. return outputs
  187. else:
  188. with self._graph.as_default():
  189. with self._session.as_default():
  190. feed_dict = {self.input_images: input['img']}
  191. sess_outputs = self._session.run(
  192. self.output, feed_dict=feed_dict)
  193. return sess_outputs
  194. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  195. if self.model_type == 'DBNet':
  196. result = {OutputKeys.POLYGONS: inputs['det_polygons']}
  197. return result
  198. else:
  199. rboxes = inputs['combined_rboxes'][0]
  200. count = inputs['combined_counts'][0]
  201. if count == 0 or count < rboxes.shape[0]:
  202. raise Exception('modelscope error: No text detected')
  203. rboxes = rboxes[:count, :]
  204. # convert rboxes to polygons and find its coordinates on the original image
  205. orig_h, orig_w = inputs['orig_size']
  206. resize_h, resize_w = inputs['resize_size']
  207. polygons = rboxes_to_polygons(rboxes)
  208. scale_y = float(orig_h) / float(resize_h)
  209. scale_x = float(orig_w) / float(resize_w)
  210. # confine polygons inside image
  211. polygons[:, ::2] = np.maximum(
  212. 0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1))
  213. polygons[:, 1::2] = np.maximum(
  214. 0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1))
  215. polygons = np.round(polygons).astype(np.int32)
  216. # nms
  217. dt_n9 = [o + [cal_width(o)] for o in polygons.tolist()]
  218. dt_nms = nms_python(dt_n9)
  219. dt_polygons = np.array([o[:8] for o in dt_nms])
  220. result = {OutputKeys.POLYGONS: dt_polygons}
  221. return result