image.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from PIL import Image
  15. from paddle.utils import try_import
  16. __all__ = []
  17. _image_backend = 'pil'
  18. def set_image_backend(backend):
  19. """
  20. Specifies the backend used to load images in class :ref:`api_paddle_datasets_ImageFolder`
  21. and :ref:`api_paddle_datasets_DatasetFolder` . Now support backends are pillow and opencv.
  22. If backend not set, will use 'pil' as default.
  23. Args:
  24. backend (str): Name of the image load backend, should be one of {'pil', 'cv2'}.
  25. Examples:
  26. .. code-block:: python
  27. >>> import os
  28. >>> import shutil
  29. >>> import tempfile
  30. >>> import numpy as np
  31. >>> from PIL import Image
  32. >>> from paddle.vision import DatasetFolder
  33. >>> from paddle.vision import set_image_backend
  34. >>> set_image_backend('pil')
  35. >>> def make_fake_dir():
  36. ... data_dir = tempfile.mkdtemp()
  37. ...
  38. ... for i in range(2):
  39. ... sub_dir = os.path.join(data_dir, 'class_' + str(i))
  40. ... if not os.path.exists(sub_dir):
  41. ... os.makedirs(sub_dir)
  42. ... for j in range(2):
  43. ... fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
  44. ... fake_img.save(os.path.join(sub_dir, str(j) + '.png'))
  45. ... return data_dir
  46. >>> temp_dir = make_fake_dir()
  47. >>> pil_data_folder = DatasetFolder(temp_dir)
  48. >>> for items in pil_data_folder:
  49. ... break
  50. >>> print(type(items[0]))
  51. <class 'PIL.Image.Image'>
  52. >>> # use opencv as backend
  53. >>> set_image_backend('cv2')
  54. >>> cv2_data_folder = DatasetFolder(temp_dir)
  55. >>> for items in cv2_data_folder:
  56. ... break
  57. >>> print(type(items[0]))
  58. <class 'numpy.ndarray'>
  59. >>> shutil.rmtree(temp_dir)
  60. """
  61. global _image_backend
  62. if backend not in ['pil', 'cv2', 'tensor']:
  63. raise ValueError(
  64. f"Expected backend are one of ['pil', 'cv2', 'tensor'], but got {backend}"
  65. )
  66. _image_backend = backend
  67. def get_image_backend():
  68. """
  69. Gets the name of the package used to load images
  70. Returns:
  71. str: backend of image load.
  72. Examples:
  73. .. code-block:: python
  74. >>> from paddle.vision import get_image_backend
  75. >>> backend = get_image_backend()
  76. >>> print(backend)
  77. pil
  78. """
  79. return _image_backend
  80. def image_load(path, backend=None):
  81. """Load an image.
  82. Args:
  83. path (str): Path of the image.
  84. backend (str, optional): The image decoding backend type. Options are
  85. `cv2`, `pil`, `None`. If backend is None, the global _imread_backend
  86. specified by :ref:`api_paddle_vision_set_image_backend` will be used. Default: None.
  87. Returns:
  88. PIL.Image or np.array: Loaded image.
  89. Examples:
  90. .. code-block:: python
  91. >>> import numpy as np
  92. >>> from PIL import Image
  93. >>> from paddle.vision import image_load, set_image_backend
  94. >>> fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
  95. >>> path = 'temp.png'
  96. >>> fake_img.save(path)
  97. >>> set_image_backend('pil')
  98. >>> pil_img = image_load(path).convert('RGB')
  99. >>> print(type(pil_img))
  100. <class 'PIL.Image.Image'>
  101. >>> # use opencv as backend
  102. >>> set_image_backend('cv2')
  103. >>> np_img = image_load(path)
  104. >>> print(type(np_img))
  105. <class 'numpy.ndarray'>
  106. """
  107. if backend is None:
  108. backend = _image_backend
  109. if backend not in ['pil', 'cv2', 'tensor']:
  110. raise ValueError(
  111. f"Expected backend are one of ['pil', 'cv2', 'tensor'], but got {backend}"
  112. )
  113. if backend == 'pil':
  114. return Image.open(path)
  115. elif backend == 'cv2':
  116. cv2 = try_import('cv2')
  117. return cv2.imread(path)