iaa_augment.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
  17. """
  18. import os
  19. # Prevent automatic updates in Albumentations for stability in augmentation behavior
  20. os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
  21. import numpy as np
  22. import albumentations as A
  23. from albumentations.core.transforms_interface import DualTransform
  24. from albumentations.augmentations.geometric import functional as fgeometric
  25. from packaging import version
  26. ALBU_VERSION = version.parse(A.__version__)
  27. IS_ALBU_NEW_VERSION = ALBU_VERSION >= version.parse("1.4.15")
  28. # Custom resize transformation mimicking Imgaug's behavior with scaling
  29. class ImgaugLikeResize(DualTransform):
  30. def __init__(self, scale_range=(0.5, 3.0), interpolation=1, p=1.0):
  31. super(ImgaugLikeResize, self).__init__(p)
  32. self.scale_range = scale_range
  33. self.interpolation = interpolation
  34. # Resize the image based on a randomly chosen scale within the scale range
  35. def apply(self, img, scale=1.0, **params):
  36. height, width = img.shape[:2]
  37. new_height = int(height * scale)
  38. new_width = int(width * scale)
  39. if IS_ALBU_NEW_VERSION:
  40. return fgeometric.resize(
  41. img, (new_height, new_width), interpolation=self.interpolation
  42. )
  43. return fgeometric.resize(
  44. img, new_height, new_width, interpolation=self.interpolation
  45. )
  46. # Apply the same scaling transformation to keypoints (e.g., polygon points)
  47. def apply_to_keypoints(self, keypoints, scale=1.0, **params):
  48. return np.array(
  49. [(x * scale, y * scale) + tuple(rest) for x, y, *rest in keypoints]
  50. )
  51. # Get random scale parameter within the specified range
  52. def get_params(self):
  53. scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
  54. return {"scale": scale}
  55. # Builder class to translate custom augmenter arguments into Albumentations-compatible format
  56. class AugmenterBuilder(object):
  57. def __init__(self):
  58. # Map common Imgaug transformations to equivalent Albumentations transforms
  59. self.imgaug_to_albu = {
  60. "Fliplr": "HorizontalFlip",
  61. "Flipud": "VerticalFlip",
  62. "Affine": "Affine",
  63. # Additional mappings can be added here if needed
  64. }
  65. # Recursive method to construct augmentation pipeline based on provided arguments
  66. def build(self, args, root=True):
  67. if args is None or len(args) == 0:
  68. return None
  69. elif isinstance(args, list):
  70. # Build the full augmentation sequence if it's a root-level call
  71. if root:
  72. sequence = [self.build(value, root=False) for value in args]
  73. return A.Compose(
  74. sequence,
  75. keypoint_params=A.KeypointParams(
  76. format="xy", remove_invisible=False
  77. ),
  78. )
  79. else:
  80. # Build individual augmenters for nested arguments
  81. augmenter_type = args[0]
  82. augmenter_args = args[1] if len(args) > 1 else {}
  83. augmenter_args_mapped = self.map_arguments(
  84. augmenter_type, augmenter_args
  85. )
  86. augmenter_type_mapped = self.imgaug_to_albu.get(
  87. augmenter_type, augmenter_type
  88. )
  89. if augmenter_type_mapped == "Resize":
  90. return ImgaugLikeResize(**augmenter_args_mapped)
  91. else:
  92. cls = getattr(A, augmenter_type_mapped)
  93. return cls(
  94. **{
  95. k: self.to_tuple_if_list(v)
  96. for k, v in augmenter_args_mapped.items()
  97. }
  98. )
  99. elif isinstance(args, dict):
  100. # Process individual transformation specified as dictionary
  101. augmenter_type = args["type"]
  102. augmenter_args = args.get("args", {})
  103. augmenter_args_mapped = self.map_arguments(augmenter_type, augmenter_args)
  104. augmenter_type_mapped = self.imgaug_to_albu.get(
  105. augmenter_type, augmenter_type
  106. )
  107. if augmenter_type_mapped == "Resize":
  108. return ImgaugLikeResize(**augmenter_args_mapped)
  109. else:
  110. cls = getattr(A, augmenter_type_mapped)
  111. return cls(
  112. **{
  113. k: self.to_tuple_if_list(v)
  114. for k, v in augmenter_args_mapped.items()
  115. }
  116. )
  117. else:
  118. raise RuntimeError("Unknown augmenter arg: " + str(args))
  119. # Map arguments to expected format for each augmenter type
  120. def map_arguments(self, augmenter_type, augmenter_args):
  121. augmenter_args = augmenter_args.copy() # Avoid modifying the original arguments
  122. if augmenter_type == "Resize":
  123. # Ensure size is a valid 2-element list or tuple
  124. size = augmenter_args.get("size")
  125. if size:
  126. if not isinstance(size, (list, tuple)) or len(size) != 2:
  127. raise ValueError(
  128. f"'size' must be a list or tuple of two numbers, but got {size}"
  129. )
  130. min_scale, max_scale = size
  131. return {
  132. "scale_range": (min_scale, max_scale),
  133. "interpolation": 1, # Linear interpolation
  134. "p": 1.0,
  135. }
  136. else:
  137. return {"scale_range": (1.0, 1.0), "interpolation": 1, "p": 1.0}
  138. elif augmenter_type == "Affine":
  139. # Map rotation to a tuple and ensure p=1.0 to apply transformation
  140. rotate = augmenter_args.get("rotate", 0)
  141. if isinstance(rotate, list):
  142. rotate = tuple(rotate)
  143. elif isinstance(rotate, (int, float)):
  144. rotate = (float(rotate), float(rotate))
  145. augmenter_args["rotate"] = rotate
  146. augmenter_args["p"] = 1.0
  147. return augmenter_args
  148. else:
  149. # For other augmenters, ensure 'p' probability is specified
  150. p = augmenter_args.get("p", 1.0)
  151. augmenter_args["p"] = p
  152. return augmenter_args
  153. # Convert lists to tuples for Albumentations compatibility
  154. def to_tuple_if_list(self, obj):
  155. if isinstance(obj, list):
  156. return tuple(obj)
  157. return obj
  158. # Wrapper class for image and polygon transformations using Imgaug-style augmentation
  159. class IaaAugment:
  160. def __init__(self, augmenter_args=None, **kwargs):
  161. if augmenter_args is None:
  162. # Default augmenters if none are specified
  163. augmenter_args = [
  164. {"type": "Fliplr", "args": {"p": 0.5}},
  165. {"type": "Affine", "args": {"rotate": [-10, 10]}},
  166. {"type": "Resize", "args": {"size": [0.5, 3]}},
  167. ]
  168. self.augmenter = AugmenterBuilder().build(augmenter_args)
  169. # Apply the augmentations to image and polygon data
  170. def __call__(self, data):
  171. image = data["image"]
  172. if self.augmenter:
  173. # Flatten polygons to individual keypoints for transformation
  174. keypoints = []
  175. keypoints_lengths = []
  176. for poly in data["polys"]:
  177. keypoints.extend([tuple(point) for point in poly])
  178. keypoints_lengths.append(len(poly))
  179. # Apply the augmentation pipeline to image and keypoints
  180. transformed = self.augmenter(image=image, keypoints=keypoints)
  181. data["image"] = transformed["image"]
  182. # Extract transformed keypoints and reconstruct polygon structures
  183. transformed_keypoints = transformed["keypoints"]
  184. # Reassemble polygons from transformed keypoints
  185. new_polys = []
  186. idx = 0
  187. for length in keypoints_lengths:
  188. new_poly = transformed_keypoints[idx : idx + length]
  189. new_polys.append(np.array([kp[:2] for kp in new_poly]))
  190. idx += length
  191. data["polys"] = np.array(new_polys)
  192. return data