utility.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. import cv2
  17. import random
  18. import numpy as np
  19. import paddle
  20. import importlib.util
  21. import sys
  22. import subprocess
  23. def print_dict(d, logger, delimiter=0):
  24. """
  25. Recursively visualize a dict and
  26. indenting acrrording by the relationship of keys.
  27. """
  28. for k, v in sorted(d.items()):
  29. if isinstance(v, dict):
  30. logger.info("{}{} : ".format(delimiter * " ", str(k)))
  31. print_dict(v, logger, delimiter + 4)
  32. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  33. logger.info("{}{} : ".format(delimiter * " ", str(k)))
  34. for value in v:
  35. print_dict(value, logger, delimiter + 4)
  36. else:
  37. logger.info("{}{} : {}".format(delimiter * " ", k, v))
  38. def get_check_global_params(mode):
  39. check_params = [
  40. "use_gpu",
  41. "max_text_length",
  42. "image_shape",
  43. "image_shape",
  44. "character_type",
  45. "loss_type",
  46. ]
  47. if mode == "train_eval":
  48. check_params = check_params + [
  49. "train_batch_size_per_card",
  50. "test_batch_size_per_card",
  51. ]
  52. elif mode == "test":
  53. check_params = check_params + ["test_batch_size_per_card"]
  54. return check_params
  55. def _check_image_file(path):
  56. img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
  57. return any([path.lower().endswith(e) for e in img_end])
  58. def get_image_file_list(img_file, infer_list=None):
  59. imgs_lists = []
  60. if infer_list and not os.path.exists(infer_list):
  61. raise Exception("not found infer list {}".format(infer_list))
  62. if infer_list:
  63. with open(infer_list, "r") as f:
  64. lines = f.readlines()
  65. for line in lines:
  66. image_path = line.strip().split("\t")[0]
  67. image_path = os.path.join(img_file, image_path)
  68. imgs_lists.append(image_path)
  69. else:
  70. if img_file is None or not os.path.exists(img_file):
  71. raise Exception("not found any img file in {}".format(img_file))
  72. img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
  73. if os.path.isfile(img_file) and _check_image_file(img_file):
  74. imgs_lists.append(img_file)
  75. elif os.path.isdir(img_file):
  76. for single_file in os.listdir(img_file):
  77. file_path = os.path.join(img_file, single_file)
  78. if os.path.isfile(file_path) and _check_image_file(file_path):
  79. imgs_lists.append(file_path)
  80. if len(imgs_lists) == 0:
  81. raise Exception("not found any img file in {}".format(img_file))
  82. imgs_lists = sorted(imgs_lists)
  83. return imgs_lists
  84. def binarize_img(img):
  85. if len(img.shape) == 3 and img.shape[2] == 3:
  86. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # conversion to grayscale image
  87. # use cv2 threshold binarization
  88. _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  89. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  90. return img
  91. def alpha_to_color(img, alpha_color=(255, 255, 255)):
  92. if len(img.shape) == 3 and img.shape[2] == 4:
  93. B, G, R, A = cv2.split(img)
  94. alpha = A / 255
  95. R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
  96. G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
  97. B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
  98. img = cv2.merge((B, G, R))
  99. return img
  100. def check_and_read(img_path):
  101. if os.path.basename(img_path)[-3:].lower() == "gif":
  102. gif = cv2.VideoCapture(img_path)
  103. ret, frame = gif.read()
  104. if not ret:
  105. logger = logging.getLogger("ppocr")
  106. logger.info("Cannot read {}. This gif image maybe corrupted.")
  107. return None, False
  108. if len(frame.shape) == 2 or frame.shape[-1] == 1:
  109. frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
  110. imgvalue = frame[:, :, ::-1]
  111. return imgvalue, True, False
  112. elif os.path.basename(img_path)[-3:].lower() == "pdf":
  113. from paddle.utils import try_import
  114. fitz = try_import("fitz")
  115. from PIL import Image
  116. imgs = []
  117. with fitz.open(img_path) as pdf:
  118. for pg in range(0, pdf.page_count):
  119. page = pdf[pg]
  120. mat = fitz.Matrix(2, 2)
  121. pm = page.get_pixmap(matrix=mat, alpha=False)
  122. # if width or height > 2000 pixels, don't enlarge the image
  123. if pm.width > 2000 or pm.height > 2000:
  124. pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
  125. img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
  126. img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  127. imgs.append(img)
  128. return imgs, False, True
  129. return None, False, False
  130. def load_vqa_bio_label_maps(label_map_path):
  131. with open(label_map_path, "r", encoding="utf-8") as fin:
  132. lines = fin.readlines()
  133. old_lines = [line.strip() for line in lines]
  134. lines = ["O"]
  135. for line in old_lines:
  136. # "O" has already been in lines
  137. if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
  138. continue
  139. lines.append(line)
  140. labels = ["O"]
  141. for line in lines[1:]:
  142. labels.append("B-" + line)
  143. labels.append("I-" + line)
  144. label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
  145. id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
  146. return label2id_map, id2label_map
  147. def set_seed(seed=1024):
  148. random.seed(seed)
  149. np.random.seed(seed)
  150. paddle.seed(seed)
  151. def check_install(module_name, install_name):
  152. spec = importlib.util.find_spec(module_name)
  153. if spec is None:
  154. print(f"Warning! The {module_name} module is NOT installed")
  155. print(
  156. f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
  157. )
  158. python = sys.executable
  159. try:
  160. subprocess.check_call(
  161. [python, "-m", "pip", "install", install_name],
  162. stdout=subprocess.DEVNULL,
  163. )
  164. print(f"The {module_name} module is now installed")
  165. except subprocess.CalledProcessError as exc:
  166. raise Exception(f"Install {module_name} failed, please install manually")
  167. else:
  168. print(f"{module_name} has been installed.")
  169. class AverageMeter:
  170. def __init__(self):
  171. self.reset()
  172. def reset(self):
  173. """reset"""
  174. self.val = 0
  175. self.avg = 0
  176. self.sum = 0
  177. self.count = 0
  178. def update(self, val, n=1):
  179. """update"""
  180. self.val = val
  181. self.sum += val * n
  182. self.count += n
  183. self.avg = self.sum / self.count