# test_rec_postprocess.py (4.7 KB)
  1. import os
  2. import sys
  3. import numpy as np
  4. import pytest
  5. current_dir = os.path.dirname(os.path.abspath(__file__))
  6. sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
  7. from ppocr.postprocess.rec_postprocess import BaseRecLabelDecode
  8. class TestBaseRecLabelDecode:
  9. """Tests for BaseRecLabelDecode.get_word_info() method."""
  @pytest.fixture
  def decoder(self):
      """Return a BaseRecLabelDecode built with its default arguments."""
      label_decoder = BaseRecLabelDecode()
      return label_decoder
  def test_get_word_info_with_german_accented_chars(self, decoder):
      """Check that "Grüßen" comes back as one "en&num" word despite ü and ß."""
      expected = "Grüßen"
      # Mark every character as selected.
      all_selected = np.ones(len(expected), dtype=bool)

      words, _, states = decoder.get_word_info(expected, all_selected)

      assert len(words) == 1, "German word should not be split"
      rebuilt = "".join(words[0])
      assert rebuilt == "Grüßen"
      assert states[0] == "en&num"
  def test_get_word_info_with_longer_german_word(self, decoder):
      """Check that "ungewöhnlichen" is kept whole and tagged "en&num"."""
      expected = "ungewöhnlichen"
      char_mask = np.ones(len(expected), dtype=bool)

      result = decoder.get_word_info(expected, char_mask)
      words, _, states = result

      assert len(words) == 1, "German word should not be split"
      assert "".join(words[0]) == "ungewöhnlichen"
      first_state = states[0]
      assert first_state == "en&num"
  def test_get_word_info_with_french_accented_chars(self, decoder):
      """Check that "café" is returned as a single word."""
      expected = "café"
      char_mask = np.ones(len(expected), dtype=bool)

      # Only the word list is inspected; the other two results are ignored.
      words, _, _ = decoder.get_word_info(expected, char_mask)

      assert len(words) == 1, "French word should not be split"
      rebuilt = "".join(words[0])
      assert rebuilt == "café"
  def test_get_word_info_underscore_as_splitter(self, decoder):
      """Check that "hello_world" is split into "hello" and "world"."""
      source = "hello_world"
      char_mask = np.ones(len(source), dtype=bool)

      words, _, _ = decoder.get_word_info(source, char_mask)

      assert len(words) == 2, "Underscore should split words"
      # Compare each returned word, in order, with the expected piece.
      for chars, expected in zip(words, ("hello", "world")):
          assert "".join(chars) == expected
  def test_get_word_info_with_mixed_content(self, decoder):
      """Check that "Grüßen Sie" yields the two words "Grüßen" and "Sie"."""
      source = "Grüßen Sie"
      all_selected = np.ones(len(source), dtype=bool)

      words, _, _ = decoder.get_word_info(source, all_selected)

      assert len(words) == 2, "Should have two words separated by space"
      first_word = "".join(words[0])
      second_word = "".join(words[1])
      assert first_word == "Grüßen"
      assert second_word == "Sie"
  def test_get_word_info_with_french_apostrophe(self, decoder):
      """Check that the apostrophe in "n'êtes" does not split the word."""
      expected = "n'êtes"
      char_mask = np.ones(len(expected), dtype=bool)

      words, _, _ = decoder.get_word_info(expected, char_mask)

      # The apostrophe is expected to join both halves into one word.
      assert len(words) == 1, "French apostrophe should connect words"
      rebuilt = "".join(words[0])
      assert rebuilt == "n'êtes"
  def test_get_word_info_with_ascii_only(self, decoder):
      """Check that plain ASCII "hello world" still splits on the space."""
      source = "hello world"
      all_selected = np.ones(len(source), dtype=bool)

      words, _, _ = decoder.get_word_info(source, all_selected)

      assert len(words) == 2
      # Compare each returned word, in order, with the expected piece.
      for chars, expected in zip(words, ("hello", "world")):
          assert "".join(chars) == expected
  def test_get_word_info_with_numbers(self, decoder):
      """Check that the hyphenated letter-digit token "VGG-16" stays one word."""
      expected = "VGG-16"
      char_mask = np.ones(len(expected), dtype=bool)

      words, _, _ = decoder.get_word_info(expected, char_mask)

      assert len(words) == 1, "Hyphenated word-number should stay together"
      rebuilt = "".join(words[0])
      assert rebuilt == "VGG-16"
  def test_get_word_info_with_floating_point(self, decoder):
      """Check that "price 3.14" splits into "price" and an intact "3.14"."""
      source = "price 3.14"
      all_selected = np.ones(len(source), dtype=bool)

      words, _, _ = decoder.get_word_info(source, all_selected)

      assert len(words) == 2
      label = "".join(words[0])
      number = "".join(words[1])
      assert label == "price"
      assert number == "3.14"
  def test_get_word_info_with_chinese(self, decoder):
      """Check that "你好啊" is grouped into one word with state "cn"."""
      expected = "你好啊"
      char_mask = np.ones(len(expected), dtype=bool)

      words, _, states = decoder.get_word_info(expected, char_mask)

      assert len(words) == 1
      rebuilt = "".join(words[0])
      assert rebuilt == "你好啊"
      first_state = states[0]
      assert first_state == "cn"