onnx_paddleocr.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import time
  2. from .predict_system import TextSystem
  3. from .utils import infer_args as init_args
  4. from .utils import str2bool, draw_ocr
  5. import argparse
  6. import sys
  7. class ONNXPaddleOcr(TextSystem):
  8. def __init__(self, **kwargs):
  9. # 默认参数
  10. parser = init_args()
  11. inference_args_dict = {}
  12. for action in parser._actions:
  13. inference_args_dict[action.dest] = action.default
  14. params = argparse.Namespace(**inference_args_dict)
  15. # params.rec_image_shape = "3, 32, 320"
  16. params.rec_image_shape = "3, 48, 320"
  17. # 根据传入的参数覆盖更新默认参数
  18. params.__dict__.update(**kwargs)
  19. # 初始化模型
  20. super().__init__(params)
  21. def ocr(self, img, det=True, rec=True, cls=True):
  22. if cls == True and self.use_angle_cls == False:
  23. print(
  24. "Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process"
  25. )
  26. if det and rec:
  27. ocr_res = []
  28. dt_boxes, rec_res = self.__call__(img, cls)
  29. tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
  30. ocr_res.append(tmp_res)
  31. return ocr_res
  32. elif det and not rec:
  33. ocr_res = []
  34. dt_boxes = self.text_detector(img)
  35. tmp_res = [box.tolist() for box in dt_boxes]
  36. ocr_res.append(tmp_res)
  37. return ocr_res
  38. else:
  39. ocr_res = []
  40. cls_res = []
  41. if not isinstance(img, list):
  42. img = [img]
  43. if self.use_angle_cls and cls:
  44. img, cls_res_tmp = self.text_classifier(img)
  45. if not rec:
  46. cls_res.append(cls_res_tmp)
  47. rec_res = self.text_recognizer(img)
  48. ocr_res.append(rec_res)
  49. if not rec:
  50. return cls_res
  51. return ocr_res
  52. def sav2Img(org_img, result, name="draw_ocr.jpg"):
  53. # 显示结果
  54. from PIL import Image
  55. result = result[0]
  56. # image = Image.open(img_path).convert('RGB')
  57. # 图像转BGR2RGB
  58. image = org_img[:, :, ::-1]
  59. boxes = [line[0] for line in result]
  60. txts = [line[1][0] for line in result]
  61. scores = [line[1][1] for line in result]
  62. im_show = draw_ocr(image, boxes, txts, scores)
  63. im_show = Image.fromarray(im_show)
  64. im_show.save(name)
  65. if __name__ == "__main__":
  66. import cv2
  67. model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
  68. img = cv2.imread(
  69. "/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg"
  70. )
  71. s = time.time()
  72. result = model.ocr(img)
  73. e = time.time()
  74. print("total time: {:.3f}".format(e - s))
  75. print("result:", result)
  76. for box in result[0]:
  77. print(box)
  78. sav2Img(img, result)