visual.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import cv2
  15. import os
  16. import numpy as np
  17. import PIL
  18. from PIL import Image, ImageDraw, ImageFont
  19. def draw_ser_results(
  20. image, ocr_results, font_path="doc/fonts/simfang.ttf", font_size=14
  21. ):
  22. np.random.seed(2021)
  23. color = (
  24. np.random.permutation(range(255)),
  25. np.random.permutation(range(255)),
  26. np.random.permutation(range(255)),
  27. )
  28. color_map = {
  29. idx: (color[0][idx], color[1][idx], color[2][idx]) for idx in range(1, 255)
  30. }
  31. if isinstance(image, np.ndarray):
  32. image = Image.fromarray(image)
  33. elif isinstance(image, str) and os.path.isfile(image):
  34. image = Image.open(image).convert("RGB")
  35. img_new = image.copy()
  36. draw = ImageDraw.Draw(img_new)
  37. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  38. for ocr_info in ocr_results:
  39. if ocr_info["pred_id"] not in color_map:
  40. continue
  41. color = color_map[ocr_info["pred_id"]]
  42. text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
  43. if "bbox" in ocr_info:
  44. # draw with ocr engine
  45. bbox = ocr_info["bbox"]
  46. else:
  47. # draw with ocr groundtruth
  48. bbox = trans_poly_to_bbox(ocr_info["points"])
  49. draw_box_txt(bbox, text, draw, font, font_size, color)
  50. img_new = Image.blend(image, img_new, 0.7)
  51. return np.array(img_new)
  52. def draw_box_txt(bbox, text, draw, font, font_size, color):
  53. # draw ocr results outline
  54. bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
  55. draw.rectangle(bbox, fill=color)
  56. # draw ocr results
  57. if int(PIL.__version__.split(".")[0]) < 10:
  58. tw = font.getsize(text)[0]
  59. th = font.getsize(text)[1]
  60. else:
  61. left, top, right, bottom = font.getbbox(text)
  62. tw, th = right - left, bottom - top
  63. start_y = max(0, bbox[0][1] - th)
  64. draw.rectangle(
  65. [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)],
  66. fill=(0, 0, 255),
  67. )
  68. draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
  69. def trans_poly_to_bbox(poly):
  70. x1 = np.min([p[0] for p in poly])
  71. x2 = np.max([p[0] for p in poly])
  72. y1 = np.min([p[1] for p in poly])
  73. y2 = np.max([p[1] for p in poly])
  74. return [x1, y1, x2, y2]
  75. def draw_re_results(image, result, font_path="doc/fonts/simfang.ttf", font_size=18):
  76. np.random.seed(0)
  77. if isinstance(image, np.ndarray):
  78. image = Image.fromarray(image)
  79. elif isinstance(image, str) and os.path.isfile(image):
  80. image = Image.open(image).convert("RGB")
  81. img_new = image.copy()
  82. draw = ImageDraw.Draw(img_new)
  83. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  84. color_head = (0, 0, 255)
  85. color_tail = (255, 0, 0)
  86. color_line = (0, 255, 0)
  87. for ocr_info_head, ocr_info_tail in result:
  88. draw_box_txt(
  89. ocr_info_head["bbox"],
  90. ocr_info_head["transcription"],
  91. draw,
  92. font,
  93. font_size,
  94. color_head,
  95. )
  96. draw_box_txt(
  97. ocr_info_tail["bbox"],
  98. ocr_info_tail["transcription"],
  99. draw,
  100. font,
  101. font_size,
  102. color_tail,
  103. )
  104. center_head = (
  105. (ocr_info_head["bbox"][0] + ocr_info_head["bbox"][2]) // 2,
  106. (ocr_info_head["bbox"][1] + ocr_info_head["bbox"][3]) // 2,
  107. )
  108. center_tail = (
  109. (ocr_info_tail["bbox"][0] + ocr_info_tail["bbox"][2]) // 2,
  110. (ocr_info_tail["bbox"][1] + ocr_info_tail["bbox"][3]) // 2,
  111. )
  112. draw.line([center_head, center_tail], fill=color_line, width=5)
  113. img_new = Image.blend(image, img_new, 0.5)
  114. return np.array(img_new)
  115. def draw_rectangle(img_path, boxes):
  116. boxes = np.array(boxes)
  117. img = cv2.imread(img_path)
  118. img_show = img.copy()
  119. for box in boxes.astype(int):
  120. x1, y1, x2, y2 = box
  121. cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
  122. return img_show