test_ocr.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
from pathlib import Path

import pytest
from paddleocr import PaddleOCR

from ..testing_utils import (
    TEST_DATA_DIR,
    check_simple_inference_result,
    check_wrapper_simple_inference_param_forwarding,
)
  8. @pytest.fixture(scope="module")
  9. def ocr_engine() -> PaddleOCR:
  10. return PaddleOCR()
  11. # TODO: Should we separate unit tests and integration tests?
  12. @pytest.mark.parametrize(
  13. "image_path",
  14. [
  15. TEST_DATA_DIR / "table.jpg",
  16. ],
  17. )
  18. def test_predict(ocr_engine: PaddleOCR, image_path: str) -> None:
  19. """
  20. Test PaddleOCR's OCR functionality.
  21. Args:
  22. ocr_engine: An instance of `PaddleOCR`.
  23. image_path: Path to the image to be processed.
  24. """
  25. result = ocr_engine.predict(str(image_path))
  26. check_simple_inference_result(result)
  27. res = result[0]
  28. assert len(res["dt_polys"]) > 0
  29. assert isinstance(res["rec_texts"], list)
  30. assert len(res["rec_texts"]) > 0
  31. for text in res["rec_texts"]:
  32. assert isinstance(text, str)
  33. # TODO: Also check passing `None`
  34. @pytest.mark.parametrize(
  35. "params",
  36. [
  37. {"use_doc_orientation_classify": False},
  38. {"use_doc_unwarping": False},
  39. {"use_textline_orientation": False},
  40. {"text_det_limit_side_len": 640, "text_det_limit_type": "min"},
  41. {"text_det_thresh": 0.5},
  42. {"text_det_box_thresh": 0.3},
  43. {"text_det_unclip_ratio": 3.0},
  44. {"text_rec_score_thresh": 0.5},
  45. ],
  46. )
  47. def test_predict_params(
  48. monkeypatch,
  49. ocr_engine: PaddleOCR,
  50. params: dict,
  51. ) -> None:
  52. check_wrapper_simple_inference_param_forwarding(
  53. monkeypatch,
  54. ocr_engine,
  55. "paddlex_pipeline",
  56. "dummy_path",
  57. params,
  58. )
  59. # TODO: Test init params
  60. def test_lang_and_ocr_version():
  61. ocr_engine = PaddleOCR(lang="ch", ocr_version="PP-OCRv5")
  62. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv5_server_det"
  63. assert ocr_engine._params["text_recognition_model_name"] == "PP-OCRv5_server_rec"
  64. ocr_engine = PaddleOCR(lang="chinese_cht", ocr_version="PP-OCRv5")
  65. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv5_server_det"
  66. assert ocr_engine._params["text_recognition_model_name"] == "PP-OCRv5_server_rec"
  67. ocr_engine = PaddleOCR(lang="en", ocr_version="PP-OCRv5")
  68. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv5_server_det"
  69. assert ocr_engine._params["text_recognition_model_name"] == "en_PP-OCRv5_mobile_rec"
  70. ocr_engine = PaddleOCR(lang="japan", ocr_version="PP-OCRv5")
  71. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv5_server_det"
  72. assert ocr_engine._params["text_recognition_model_name"] == "PP-OCRv5_server_rec"
  73. ocr_engine = PaddleOCR(lang="ch", ocr_version="PP-OCRv4")
  74. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv4_mobile_det"
  75. assert ocr_engine._params["text_recognition_model_name"] == "PP-OCRv4_mobile_rec"
  76. ocr_engine = PaddleOCR(lang="en", ocr_version="PP-OCRv4")
  77. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv4_mobile_det"
  78. assert ocr_engine._params["text_recognition_model_name"] == "en_PP-OCRv4_mobile_rec"
  79. ocr_engine = PaddleOCR(lang="ch", ocr_version="PP-OCRv3")
  80. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  81. assert ocr_engine._params["text_recognition_model_name"] == "PP-OCRv3_mobile_rec"
  82. ocr_engine = PaddleOCR(lang="en", ocr_version="PP-OCRv3")
  83. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  84. assert ocr_engine._params["text_recognition_model_name"] == "en_PP-OCRv3_mobile_rec"
  85. ocr_engine = PaddleOCR(lang="fr", ocr_version="PP-OCRv3")
  86. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  87. assert (
  88. ocr_engine._params["text_recognition_model_name"] == "latin_PP-OCRv3_mobile_rec"
  89. )
  90. ocr_engine = PaddleOCR(lang="ar", ocr_version="PP-OCRv3")
  91. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  92. assert (
  93. ocr_engine._params["text_recognition_model_name"]
  94. == "arabic_PP-OCRv3_mobile_rec"
  95. )
  96. ocr_engine = PaddleOCR(lang="ru", ocr_version="PP-OCRv3")
  97. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  98. assert (
  99. ocr_engine._params["text_recognition_model_name"]
  100. == "cyrillic_PP-OCRv3_mobile_rec"
  101. )
  102. ocr_engine = PaddleOCR(lang="hi", ocr_version="PP-OCRv3")
  103. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  104. assert (
  105. ocr_engine._params["text_recognition_model_name"]
  106. == "devanagari_PP-OCRv3_mobile_rec"
  107. )
  108. ocr_engine = PaddleOCR(lang="japan", ocr_version="PP-OCRv3")
  109. assert ocr_engine._params["text_detection_model_name"] == "PP-OCRv3_mobile_det"
  110. assert (
  111. ocr_engine._params["text_recognition_model_name"] == "japan_PP-OCRv3_mobile_rec"
  112. )