test_cls_postprocess.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import os
  2. import sys
  3. import numpy as np
  4. import paddle
  5. import pytest
  6. current_dir = os.path.dirname(os.path.abspath(__file__))
  7. sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
  8. from ppocr.postprocess.cls_postprocess import ClsPostProcess
  9. # Fixtures for common test inputs
  10. @pytest.fixture
  11. def preds_tensor():
  12. return paddle.to_tensor(np.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]))
  13. @pytest.fixture
  14. def label_list():
  15. return {0: "class0", 1: "class1", 2: "class2"}
  16. # Parameterize tests to cover multiple scenarios
  17. @pytest.mark.parametrize(
  18. "label_list, expected",
  19. [
  20. ({0: "class0", 1: "class1", 2: "class2"}, [("class1", 0.7), ("class2", 0.4)]),
  21. (None, [(1, 0.7), (2, 0.4)]),
  22. ],
  23. )
  24. def test_cls_post_process_with_and_without_label_list(
  25. preds_tensor, label_list, expected
  26. ):
  27. post_process = ClsPostProcess(label_list=label_list)
  28. result = post_process(preds_tensor)
  29. assert isinstance(result, list), "Result should be a list"
  30. assert result == expected, f"Expected {expected}, got {result}"
  31. # Test with a key in the prediction dictionary
  32. def test_cls_post_process_with_key(preds_tensor, label_list):
  33. preds_dict = {"key": preds_tensor}
  34. post_process = ClsPostProcess(label_list=label_list, key="key")
  35. result = post_process(preds_dict)
  36. expected = [("class1", 0.7), ("class2", 0.4)]
  37. assert isinstance(result, list), "Result should be a list"
  38. assert result == expected, f"Expected {expected}, got {result}"
  39. # Test with label input
  40. def test_cls_post_process_with_label(preds_tensor, label_list):
  41. labels = [2, 0]
  42. post_process = ClsPostProcess(label_list=label_list)
  43. result, label_result = post_process(preds_tensor, labels)
  44. expected_result = [("class1", 0.7), ("class2", 0.4)]
  45. expected_label_result = [("class2", 1.0), ("class0", 1.0)]
  46. assert isinstance(result, list), "Result should be a list"
  47. assert result == expected_result, f"Expected {expected_result}, got {result}"
  48. assert isinstance(label_result, list), "Label result should be a list"
  49. assert (
  50. label_result == expected_label_result
  51. ), f"Expected {expected_label_result}, got {label_result}"