# predict_rec.py
import math

import cv2
import numpy as np
from PIL import Image

from .predict_base import PredictBase
from .rec_postprocess import CTCLabelDecode
  7. class TextRecognizer(PredictBase):
  8. def __init__(self, args):
  9. self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
  10. self.rec_batch_num = args.rec_batch_num
  11. self.rec_algorithm = args.rec_algorithm
  12. self.postprocess_op = CTCLabelDecode(
  13. character_dict_path=args.rec_char_dict_path,
  14. use_space_char=args.use_space_char,
  15. )
  16. # 初始化模型
  17. self.rec_onnx_session = self.get_onnx_session(args.rec_model_dir, args.use_gpu)
  18. self.rec_input_name = self.get_input_name(self.rec_onnx_session)
  19. self.rec_output_name = self.get_output_name(self.rec_onnx_session)
  20. def resize_norm_img(self, img, max_wh_ratio):
  21. imgC, imgH, imgW = self.rec_image_shape
  22. if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
  23. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  24. # return padding_im
  25. image_pil = Image.fromarray(np.uint8(img))
  26. if self.rec_algorithm == "ViTSTR":
  27. img = image_pil.resize([imgW, imgH], Image.BICUBIC)
  28. else:
  29. img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
  30. img = np.array(img)
  31. norm_img = np.expand_dims(img, -1)
  32. norm_img = norm_img.transpose((2, 0, 1))
  33. if self.rec_algorithm == "ViTSTR":
  34. norm_img = norm_img.astype(np.float32) / 255.0
  35. else:
  36. norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
  37. return norm_img
  38. elif self.rec_algorithm == "RFL":
  39. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  40. resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
  41. resized_image = resized_image.astype("float32")
  42. resized_image = resized_image / 255
  43. resized_image = resized_image[np.newaxis, :]
  44. resized_image -= 0.5
  45. resized_image /= 0.5
  46. return resized_image
  47. assert imgC == img.shape[2]
  48. imgW = int((imgH * max_wh_ratio))
  49. # w = self.rec_onnx_session.get_inputs()[0].shape[3:][0]
  50. # w = self.rec_onnx_session.get_inputs()[0].shape[3:][0]
  51. # print(w)
  52. # if w is not None and w > 0:
  53. # imgW = w
  54. h, w = img.shape[:2]
  55. ratio = w / float(h)
  56. if math.ceil(imgH * ratio) > imgW:
  57. resized_w = imgW
  58. else:
  59. resized_w = int(math.ceil(imgH * ratio))
  60. if self.rec_algorithm == "RARE":
  61. if resized_w > self.rec_image_shape[2]:
  62. resized_w = self.rec_image_shape[2]
  63. imgW = self.rec_image_shape[2]
  64. resized_image = cv2.resize(img, (resized_w, imgH))
  65. resized_image = resized_image.astype("float32")
  66. resized_image = resized_image.transpose((2, 0, 1)) / 255
  67. resized_image -= 0.5
  68. resized_image /= 0.5
  69. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  70. padding_im[:, :, 0:resized_w] = resized_image
  71. return padding_im
  72. def resize_norm_img_vl(self, img, image_shape):
  73. imgC, imgH, imgW = image_shape
  74. img = img[:, :, ::-1] # bgr2rgb
  75. resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
  76. resized_image = resized_image.astype("float32")
  77. resized_image = resized_image.transpose((2, 0, 1)) / 255
  78. return resized_image
  79. def resize_norm_img_srn(self, img, image_shape):
  80. imgC, imgH, imgW = image_shape
  81. img_black = np.zeros((imgH, imgW))
  82. im_hei = img.shape[0]
  83. im_wid = img.shape[1]
  84. if im_wid <= im_hei * 1:
  85. img_new = cv2.resize(img, (imgH * 1, imgH))
  86. elif im_wid <= im_hei * 2:
  87. img_new = cv2.resize(img, (imgH * 2, imgH))
  88. elif im_wid <= im_hei * 3:
  89. img_new = cv2.resize(img, (imgH * 3, imgH))
  90. else:
  91. img_new = cv2.resize(img, (imgW, imgH))
  92. img_np = np.asarray(img_new)
  93. img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  94. img_black[:, 0 : img_np.shape[1]] = img_np
  95. img_black = img_black[:, :, np.newaxis]
  96. row, col, c = img_black.shape
  97. c = 1
  98. return np.reshape(img_black, (c, row, col)).astype(np.float32)
  99. def srn_other_inputs(self, image_shape, num_heads, max_text_length):
  100. imgC, imgH, imgW = image_shape
  101. feature_dim = int((imgH / 8) * (imgW / 8))
  102. encoder_word_pos = (
  103. np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
  104. )
  105. gsrm_word_pos = (
  106. np.array(range(0, max_text_length))
  107. .reshape((max_text_length, 1))
  108. .astype("int64")
  109. )
  110. gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
  111. gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
  112. [-1, 1, max_text_length, max_text_length]
  113. )
  114. gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
  115. "float32"
  116. ) * [-1e9]
  117. gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
  118. [-1, 1, max_text_length, max_text_length]
  119. )
  120. gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
  121. "float32"
  122. ) * [-1e9]
  123. encoder_word_pos = encoder_word_pos[np.newaxis, :]
  124. gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
  125. return [
  126. encoder_word_pos,
  127. gsrm_word_pos,
  128. gsrm_slf_attn_bias1,
  129. gsrm_slf_attn_bias2,
  130. ]
  131. def process_image_srn(self, img, image_shape, num_heads, max_text_length):
  132. norm_img = self.resize_norm_img_srn(img, image_shape)
  133. norm_img = norm_img[np.newaxis, :]
  134. [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = (
  135. self.srn_other_inputs(image_shape, num_heads, max_text_length)
  136. )
  137. gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
  138. gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
  139. encoder_word_pos = encoder_word_pos.astype(np.int64)
  140. gsrm_word_pos = gsrm_word_pos.astype(np.int64)
  141. return (
  142. norm_img,
  143. encoder_word_pos,
  144. gsrm_word_pos,
  145. gsrm_slf_attn_bias1,
  146. gsrm_slf_attn_bias2,
  147. )
  148. def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
  149. imgC, imgH, imgW_min, imgW_max = image_shape
  150. h = img.shape[0]
  151. w = img.shape[1]
  152. valid_ratio = 1.0
  153. # make sure new_width is an integral multiple of width_divisor.
  154. width_divisor = int(1 / width_downsample_ratio)
  155. # resize
  156. ratio = w / float(h)
  157. resize_w = math.ceil(imgH * ratio)
  158. if resize_w % width_divisor != 0:
  159. resize_w = round(resize_w / width_divisor) * width_divisor
  160. if imgW_min is not None:
  161. resize_w = max(imgW_min, resize_w)
  162. if imgW_max is not None:
  163. valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
  164. resize_w = min(imgW_max, resize_w)
  165. resized_image = cv2.resize(img, (resize_w, imgH))
  166. resized_image = resized_image.astype("float32")
  167. # norm
  168. if image_shape[0] == 1:
  169. resized_image = resized_image / 255
  170. resized_image = resized_image[np.newaxis, :]
  171. else:
  172. resized_image = resized_image.transpose((2, 0, 1)) / 255
  173. resized_image -= 0.5
  174. resized_image /= 0.5
  175. resize_shape = resized_image.shape
  176. padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
  177. padding_im[:, :, 0:resize_w] = resized_image
  178. pad_shape = padding_im.shape
  179. return padding_im, resize_shape, pad_shape, valid_ratio
  180. def resize_norm_img_spin(self, img):
  181. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  182. # return padding_im
  183. img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
  184. img = np.array(img, np.float32)
  185. img = np.expand_dims(img, -1)
  186. img = img.transpose((2, 0, 1))
  187. mean = [127.5]
  188. std = [127.5]
  189. mean = np.array(mean, dtype=np.float32)
  190. std = np.array(std, dtype=np.float32)
  191. mean = np.float32(mean.reshape(1, -1))
  192. stdinv = 1 / np.float32(std.reshape(1, -1))
  193. img -= mean
  194. img *= stdinv
  195. return img
  196. def resize_norm_img_svtr(self, img, image_shape):
  197. imgC, imgH, imgW = image_shape
  198. resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
  199. resized_image = resized_image.astype("float32")
  200. resized_image = resized_image.transpose((2, 0, 1)) / 255
  201. resized_image -= 0.5
  202. resized_image /= 0.5
  203. return resized_image
  204. def resize_norm_img_abinet(self, img, image_shape):
  205. imgC, imgH, imgW = image_shape
  206. resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
  207. resized_image = resized_image.astype("float32")
  208. resized_image = resized_image / 255.0
  209. mean = np.array([0.485, 0.456, 0.406])
  210. std = np.array([0.229, 0.224, 0.225])
  211. resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
  212. resized_image = resized_image.transpose((2, 0, 1))
  213. resized_image = resized_image.astype("float32")
  214. return resized_image
  215. def norm_img_can(self, img, image_shape):
  216. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
  217. if self.inverse:
  218. img = 255 - img
  219. if self.rec_image_shape[0] == 1:
  220. h, w = img.shape
  221. _, imgH, imgW = self.rec_image_shape
  222. if h < imgH or w < imgW:
  223. padding_h = max(imgH - h, 0)
  224. padding_w = max(imgW - w, 0)
  225. img_padded = np.pad(
  226. img,
  227. ((0, padding_h), (0, padding_w)),
  228. "constant",
  229. constant_values=(255),
  230. )
  231. img = img_padded
  232. img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
  233. img = img.astype("float32")
  234. return img
  235. def __call__(self, img_list):
  236. img_num = len(img_list)
  237. # Calculate the aspect ratio of all text bars
  238. width_list = []
  239. for img in img_list:
  240. width_list.append(img.shape[1] / float(img.shape[0]))
  241. # Sorting can speed up the recognition process
  242. indices = np.argsort(np.array(width_list))
  243. rec_res = [["", 0.0]] * img_num
  244. batch_num = self.rec_batch_num
  245. for beg_img_no in range(0, img_num, batch_num):
  246. end_img_no = min(img_num, beg_img_no + batch_num)
  247. norm_img_batch = []
  248. imgC, imgH, imgW = self.rec_image_shape[:3]
  249. max_wh_ratio = imgW / imgH
  250. # max_wh_ratio = 0
  251. for ino in range(beg_img_no, end_img_no):
  252. h, w = img_list[indices[ino]].shape[0:2]
  253. wh_ratio = w * 1.0 / h
  254. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  255. for ino in range(beg_img_no, end_img_no):
  256. norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
  257. norm_img = norm_img[np.newaxis, :]
  258. norm_img_batch.append(norm_img)
  259. norm_img_batch = np.concatenate(norm_img_batch)
  260. norm_img_batch = norm_img_batch.copy()
  261. # img = img[:, :, ::-1].transpose(2, 0, 1)
  262. # img = img[:, :, ::-1]
  263. # img = img.transpose(2, 0, 1)
  264. # img = img.astype(np.float32)
  265. # img = np.expand_dims(img, axis=0)
  266. # print(img.shape)
  267. input_feed = self.get_input_feed(self.rec_input_name, norm_img_batch)
  268. outputs = self.rec_onnx_session.run(
  269. self.rec_output_name, input_feed=input_feed
  270. )
  271. preds = outputs[0]
  272. rec_result = self.postprocess_op(preds)
  273. for rno in range(len(rec_result)):
  274. rec_res[indices[beg_img_no + rno]] = rec_result[rno]
  275. return rec_res