test_iaa_augment.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import os
  2. import sys
  3. import pytest
  4. import numpy as np
  5. import random
  6. current_dir = os.path.dirname(os.path.abspath(__file__))
  7. sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
  8. from ppocr.data.imaug.iaa_augment import IaaAugment
  9. # Set a fixed random seed to ensure test reproducibility
  10. np.random.seed(42)
  11. random.seed(42)
  12. # Fixture to provide a sample image for tests
  13. @pytest.fixture
  14. def sample_image():
  15. # Create a 100x100 pixel dummy image with 3 color channels (RGB)
  16. return np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  17. # Fixture to provide sample polygons for tests
  18. @pytest.fixture
  19. def sample_polys():
  20. # Create dummy polygons as sample data
  21. polys = [
  22. np.array([[10, 10], [20, 10], [20, 20], [10, 20]], dtype=np.float32),
  23. np.array([[30, 30], [40, 30], [40, 40], [30, 40]], dtype=np.float32),
  24. ]
  25. return polys
  26. # Helper function to create a data dictionary for testing
  27. def create_data(sample_image, sample_polys):
  28. return {
  29. "image": sample_image.copy(),
  30. "polys": [poly.copy() for poly in sample_polys],
  31. }
  32. # Test the default behavior of the augmenter (without specified arguments)
  33. def test_iaa_augment_default(sample_image, sample_polys):
  34. data = create_data(sample_image, sample_polys)
  35. augmenter = IaaAugment()
  36. transformed_data = augmenter(data)
  37. # Check the data types and structure of the transformed image and polygons
  38. assert isinstance(
  39. transformed_data["image"], np.ndarray
  40. ), "Image should be a numpy array"
  41. assert isinstance(
  42. transformed_data["polys"], np.ndarray
  43. ), "Polys should be a numpy array"
  44. assert transformed_data["image"].ndim == 3, "Image should be 3-dimensional"
  45. # Verify that the polygons have been transformed
  46. polys_changed = any(
  47. not np.allclose(orig_poly, trans_poly)
  48. for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
  49. )
  50. assert polys_changed, "Polygons should have been transformed"
  51. # Test the augmenter with empty arguments, meaning no transformations should occur
  52. def test_iaa_augment_none(sample_image, sample_polys):
  53. data = create_data(sample_image, sample_polys)
  54. augmenter = IaaAugment(augmenter_args=[])
  55. transformed_data = augmenter(data)
  56. # Check that the image and polygons remain unchanged
  57. assert np.array_equal(
  58. data["image"], transformed_data["image"]
  59. ), "Image should be unchanged"
  60. for orig_poly, transformed_poly in zip(data["polys"], transformed_data["polys"]):
  61. assert np.array_equal(
  62. orig_poly, transformed_poly
  63. ), "Polygons should be unchanged"
  64. # Parameterized test to check various augmenter arguments and expected image shapes
  65. @pytest.mark.parametrize(
  66. "augmenter_args, expected_shape",
  67. [
  68. ([], (100, 100, 3)),
  69. ([{"type": "Resize", "args": {"size": [0.5, 0.5]}}], (50, 50, 3)),
  70. ([{"type": "Resize", "args": {"size": [2.0, 2.0]}}], (200, 200, 3)),
  71. ],
  72. )
  73. def test_iaa_augment_resize(sample_image, sample_polys, augmenter_args, expected_shape):
  74. data = create_data(sample_image, sample_polys)
  75. augmenter = IaaAugment(augmenter_args=augmenter_args)
  76. transformed_data = augmenter(data)
  77. # Verify that the transformed image has the expected shape
  78. assert (
  79. transformed_data["image"].shape == expected_shape
  80. ), f"Expected image shape {expected_shape}, got {transformed_data['image'].shape}"
  81. # Test custom augmenter arguments with specific transformations
  82. def test_iaa_augment_custom(sample_image, sample_polys):
  83. data = create_data(sample_image, sample_polys)
  84. augmenter_args = [
  85. {"type": "Affine", "args": {"rotate": [45, 45]}}, # Apply 45-degree rotation
  86. {"type": "Resize", "args": {"size": [0.5, 0.5]}},
  87. ]
  88. augmenter = IaaAugment(augmenter_args=augmenter_args)
  89. transformed_data = augmenter(data)
  90. # Check the expected image dimensions after resizing
  91. expected_height = int(sample_image.shape[0] * 0.5)
  92. expected_width = int(sample_image.shape[1] * 0.5)
  93. assert (
  94. transformed_data["image"].shape[0] == expected_height
  95. ), "Image height should be scaled by 0.5"
  96. assert (
  97. transformed_data["image"].shape[1] == expected_width
  98. ), "Image width should be scaled by 0.5"
  99. # Verify that the polygons have been transformed
  100. polys_changed = any(
  101. not np.allclose(orig_poly, trans_poly)
  102. for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
  103. )
  104. assert polys_changed, "Polygons should have been transformed"
  105. # Test that an unknown transformation type raises an AttributeError
  106. def test_iaa_augment_unknown_transform():
  107. augmenter_args = [{"type": "UnknownTransform", "args": {}}]
  108. with pytest.raises(AttributeError):
  109. IaaAugment(augmenter_args=augmenter_args)
  110. # Test that an invalid resize size parameter raises a ValueError
  111. def test_iaa_augment_invalid_resize_size(sample_image, sample_polys):
  112. augmenter_args = [{"type": "Resize", "args": {"size": "invalid_size"}}]
  113. with pytest.raises(ValueError) as exc_info:
  114. IaaAugment(augmenter_args=augmenter_args)
  115. assert "'size' must be a list or tuple of two numbers" in str(exc_info.value)
  116. # Test that polygons are transformed as expected
  117. def test_iaa_augment_polys_transformation(sample_image, sample_polys):
  118. data = create_data(sample_image, sample_polys)
  119. augmenter_args = [
  120. {"type": "Affine", "args": {"rotate": [90, 90]}}, # Apply 90-degree rotation
  121. ]
  122. augmenter = IaaAugment(augmenter_args=augmenter_args)
  123. transformed_data = augmenter(data)
  124. # Verify that the polygons have been transformed
  125. polys_changed = any(
  126. not np.allclose(orig_poly, trans_poly)
  127. for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
  128. )
  129. assert polys_changed, "Polygons should have been transformed"
  130. # Test multiple transformations applied to the augmenter
  131. def test_iaa_augment_multiple_transforms(sample_image, sample_polys):
  132. augmenter_args = [
  133. {"type": "Fliplr", "args": {"p": 1.0}}, # Always apply horizontal flip
  134. {"type": "Affine", "args": {"shear": 10}},
  135. ]
  136. data = create_data(sample_image, sample_polys)
  137. augmenter = IaaAugment(augmenter_args=augmenter_args)
  138. transformed_data = augmenter(data)
  139. # Ensure the image has been transformed
  140. images_different = not np.array_equal(transformed_data["image"], sample_image)
  141. assert images_different, "Image should be transformed"
  142. # Ensure the polygons have been transformed
  143. polys_changed = any(
  144. not np.allclose(orig_poly, trans_poly)
  145. for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
  146. )
  147. assert polys_changed, "Polygons should have been transformed"