image.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. from typing import Any, Dict, Union
  4. import cv2
  5. import numpy as np
  6. import PIL
  7. from numpy import ndarray
  8. from PIL import Image, ImageOps
  9. from modelscope.fileio import File
  10. from modelscope.metainfo import Preprocessors
  11. from modelscope.pipeline_inputs import InputKeys
  12. from modelscope.utils.constant import Fields
  13. from modelscope.utils.type_assert import type_assert
  14. from .base import Preprocessor
  15. from .builder import PREPROCESSORS
  16. @PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image)
  17. class LoadImage:
  18. """Load an image from file or url.
  19. Added or updated keys are "filename", "img", "img_shape",
  20. "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
  21. "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
  22. Args:
  23. mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`.
  24. backend (str): Type of loading image. Should be: cv2 or pillow. Default is pillow.
  25. """
  26. def __init__(self, mode='rgb', backend='pillow'):
  27. self.mode = mode.upper()
  28. self.backend = backend
  29. def __call__(self, input: Union[str, Dict[str, str]]):
  30. """Call functions to load image and get image meta information.
  31. Args:
  32. input (str or dict): input image path or input dict with
  33. a key `filename`.
  34. Returns:
  35. dict: The dict contains loaded image.
  36. """
  37. if isinstance(input, dict):
  38. image_path_or_url = input['filename']
  39. else:
  40. image_path_or_url = input
  41. if self.backend == 'cv2':
  42. storage = File._get_storage(image_path_or_url)
  43. with storage.as_local_path(image_path_or_url) as img_path:
  44. img = cv2.imread(img_path, cv2.IMREAD_COLOR)
  45. if self.mode == 'RGB':
  46. cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
  47. img_h, img_w, img_c = img.shape[0], img.shape[1], img.shape[2]
  48. img_shape = (img_h, img_w, img_c)
  49. elif self.backend == 'pillow':
  50. bytes = File.read(image_path_or_url)
  51. # TODO @wenmeng.zwm add opencv decode as optional
  52. # we should also look at the input format which is the most commonly
  53. # used in Mind' image related models
  54. with io.BytesIO(bytes) as infile:
  55. img = Image.open(infile)
  56. img = ImageOps.exif_transpose(img)
  57. img = img.convert(self.mode)
  58. img_shape = (img.size[1], img.size[0], 3)
  59. else:
  60. raise TypeError(f'backend should be either cv2 or pillow,'
  61. f'but got {self.backend}')
  62. results = {
  63. 'filename': image_path_or_url,
  64. 'img': img,
  65. 'img_shape': img_shape,
  66. 'img_field': 'img',
  67. }
  68. if isinstance(input, dict):
  69. input_ret = input.copy()
  70. input_ret.update(results)
  71. results = input_ret
  72. return results
  73. def __repr__(self):
  74. repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})'
  75. return repr_str
  76. @staticmethod
  77. def convert_to_ndarray(input) -> ndarray:
  78. if isinstance(input, str):
  79. img = np.array(load_image(input))
  80. elif isinstance(input, PIL.Image.Image):
  81. img = np.array(input.convert('RGB'))
  82. elif isinstance(input, np.ndarray):
  83. if len(input.shape) == 2:
  84. input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
  85. img = input[:, :, ::-1]
  86. elif isinstance(input, Dict):
  87. img = input.get(InputKeys.IMAGE, None)
  88. if img:
  89. img = np.array(load_image(img))
  90. else:
  91. raise TypeError(f'input should be either str, PIL.Image,'
  92. f' np.array, but got {type(input)}')
  93. return img
  94. @staticmethod
  95. def convert_to_img(input) -> ndarray:
  96. if isinstance(input, str):
  97. img = load_image(input)
  98. elif isinstance(input, PIL.Image.Image):
  99. img = input.convert('RGB')
  100. elif isinstance(input, np.ndarray):
  101. if len(input.shape) == 2:
  102. img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
  103. img = input[:, :, ::-1]
  104. img = Image.fromarray(img.astype('uint8')).convert('RGB')
  105. elif isinstance(input, Dict):
  106. img = input.get(InputKeys.IMAGE, None)
  107. if img:
  108. img = load_image(img)
  109. else:
  110. raise TypeError(f'input should be either str, PIL.Image,'
  111. f' np.array, but got {type(input)}')
  112. return img
  113. def load_image(image_path_or_url: str) -> Image.Image:
  114. """ simple interface to load an image from file or url
  115. Args:
  116. image_path_or_url (str): image file path or http url
  117. """
  118. loader = LoadImage()
  119. return loader(image_path_or_url)['img']
  120. @PREPROCESSORS.register_module(
  121. Fields.cv, module_name=Preprocessors.object_detection_tinynas_preprocessor)
  122. class ObjectDetectionTinynasPreprocessor(Preprocessor):
  123. def __init__(self, size_divisible=32, **kwargs):
  124. """Preprocess the image.
  125. What this preprocessor will do:
  126. 1. Transpose the image matrix to make the channel the first dim.
  127. 2. If the size_divisible is gt than 0, it will be used to pad the image.
  128. 3. Expand an extra image dim as dim 0.
  129. Args:
  130. size_divisible (int): The number will be used as a length unit to pad the image.
  131. Formula: int(math.ceil(shape / size_divisible) * size_divisible)
  132. Default 32.
  133. """
  134. super().__init__(**kwargs)
  135. self.size_divisible = size_divisible
  136. @type_assert(object, object)
  137. def __call__(self, data: np.ndarray) -> Dict[str, ndarray]:
  138. """Preprocess the image.
  139. Args:
  140. data: The input image with 3 dimensions.
  141. Returns:
  142. The processed data in dict.
  143. {'img': np.ndarray}
  144. """
  145. image = data.astype(np.float32)
  146. image = image.transpose((2, 0, 1))
  147. shape = image.shape # c, h, w
  148. if self.size_divisible > 0:
  149. import math
  150. stride = self.size_divisible
  151. shape = list(shape)
  152. shape[1] = int(math.ceil(shape[1] / stride) * stride)
  153. shape[2] = int(math.ceil(shape[2] / stride) * stride)
  154. shape = tuple(shape)
  155. pad_img = np.zeros(shape).astype(np.float32)
  156. pad_img[:, :image.shape[1], :image.shape[2]] = image
  157. pad_img = np.expand_dims(pad_img, 0)
  158. return {'img': pad_img}
  159. @PREPROCESSORS.register_module(
  160. Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor)
  161. class ImageColorEnhanceFinetunePreprocessor(Preprocessor):
  162. def __init__(self, model_dir: str, *args, **kwargs):
  163. """preprocess the data from the `model_dir` path
  164. Args:
  165. model_dir (str): model path
  166. """
  167. super().__init__(*args, **kwargs)
  168. self.model_dir: str = model_dir
  169. @type_assert(object, object)
  170. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  171. """process the raw input data
  172. Args:
  173. data (tuple): [sentence1, sentence2]
  174. sentence1 (str): a sentence
  175. Example:
  176. 'you are so handsome.'
  177. sentence2 (str): a sentence
  178. Example:
  179. 'you are so beautiful.'
  180. Returns:
  181. Dict[str, Any]: the preprocessed data
  182. """
  183. return data
  184. @PREPROCESSORS.register_module(
  185. Fields.cv, module_name=Preprocessors.image_denoise_preprocessor)
  186. class ImageDenoisePreprocessor(Preprocessor):
  187. def __init__(self, model_dir: str, *args, **kwargs):
  188. """
  189. Args:
  190. model_dir (str): model path
  191. """
  192. super().__init__(*args, **kwargs)
  193. self.model_dir: str = model_dir
  194. from .common import Filter
  195. # TODO: `Filter` should be moved to configurarion file of each model
  196. self._transforms = [Filter(reserved_keys=['input', 'target'])]
  197. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  198. """process the raw input data
  199. Args:
  200. data Dict[str, Any]
  201. Returns:
  202. Dict[str, Any]: the preprocessed data
  203. """
  204. for t in self._transforms:
  205. data = t(data)
  206. return data
  207. @PREPROCESSORS.register_module(
  208. Fields.cv, module_name=Preprocessors.image_deblur_preprocessor)
  209. class ImageDeblurPreprocessor(Preprocessor):
  210. def __init__(self, model_dir: str, *args, **kwargs):
  211. """
  212. Args:
  213. model_dir (str): model path
  214. """
  215. super().__init__(*args, **kwargs)
  216. self.model_dir: str = model_dir
  217. from .common import Filter
  218. # TODO: `Filter` should be moved to configurarion file of each model
  219. self._transforms = [Filter(reserved_keys=['input', 'target'])]
  220. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  221. """process the raw input data
  222. Args:
  223. data Dict[str, Any]
  224. Returns:
  225. Dict[str, Any]: the preprocessed data
  226. """
  227. for t in self._transforms:
  228. data = t(data)
  229. return data
  230. @PREPROCESSORS.register_module(
  231. Fields.cv,
  232. module_name=Preprocessors.image_portrait_enhancement_preprocessor)
  233. class ImagePortraitEnhancementPreprocessor(Preprocessor):
  234. def __init__(self, model_dir: str, *args, **kwargs):
  235. """
  236. Args:
  237. model_dir (str): model path
  238. """
  239. super().__init__(*args, **kwargs)
  240. self.model_dir: str = model_dir
  241. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  242. """process the raw input data
  243. Args:
  244. data Dict[str, Any]
  245. Returns:
  246. Dict[str, Any]: the preprocessed data
  247. """
  248. return data
  249. @PREPROCESSORS.register_module(
  250. Fields.cv,
  251. module_name=Preprocessors.image_instance_segmentation_preprocessor)
  252. class ImageInstanceSegmentationPreprocessor(Preprocessor):
  253. def __init__(self, *args, **kwargs):
  254. """image instance segmentation preprocessor in the fine-tune scenario
  255. """
  256. super().__init__(*args, **kwargs)
  257. self.training = kwargs.pop('training', True)
  258. self.preprocessor_train_cfg = kwargs.pop('train', None)
  259. self.preprocessor_test_cfg = kwargs.pop('val', None)
  260. self.train_transforms = []
  261. self.test_transforms = []
  262. from modelscope.models.cv.image_instance_segmentation.datasets import \
  263. build_preprocess_transform
  264. if self.preprocessor_train_cfg is not None:
  265. if isinstance(self.preprocessor_train_cfg, dict):
  266. self.preprocessor_train_cfg = [self.preprocessor_train_cfg]
  267. for cfg in self.preprocessor_train_cfg:
  268. transform = build_preprocess_transform(cfg)
  269. self.train_transforms.append(transform)
  270. if self.preprocessor_test_cfg is not None:
  271. if isinstance(self.preprocessor_test_cfg, dict):
  272. self.preprocessor_test_cfg = [self.preprocessor_test_cfg]
  273. for cfg in self.preprocessor_test_cfg:
  274. transform = build_preprocess_transform(cfg)
  275. self.test_transforms.append(transform)
  276. def train(self):
  277. self.training = True
  278. return
  279. def eval(self):
  280. self.training = False
  281. return
  282. @type_assert(object, object)
  283. def __call__(self, results: Dict[str, Any]):
  284. """process the raw input data
  285. Args:
  286. results (dict): Result dict from loading pipeline.
  287. Returns:
  288. Dict[str, Any] | None: the preprocessed data
  289. """
  290. if self.training:
  291. transforms = self.train_transforms
  292. else:
  293. transforms = self.test_transforms
  294. for t in transforms:
  295. results = t(results)
  296. if results is None:
  297. return None
  298. return results
  299. @PREPROCESSORS.register_module(
  300. Fields.cv, module_name=Preprocessors.video_summarization_preprocessor)
  301. class VideoSummarizationPreprocessor(Preprocessor):
  302. def __init__(self, model_dir: str, *args, **kwargs):
  303. """
  304. Args:
  305. model_dir (str): model path
  306. """
  307. super().__init__(*args, **kwargs)
  308. self.model_dir: str = model_dir
  309. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  310. """process the raw input data
  311. Args:
  312. data Dict[str, Any]
  313. Returns:
  314. Dict[str, Any]: the preprocessed data
  315. """
  316. return data
  317. @PREPROCESSORS.register_module(
  318. Fields.cv,
  319. module_name=Preprocessors.image_classification_bypass_preprocessor)
  320. class ImageClassificationBypassPreprocessor(Preprocessor):
  321. def __init__(self, *args, **kwargs):
  322. """image classification bypass preprocessor in the fine-tune scenario
  323. """
  324. super().__init__(*args, **kwargs)
  325. self.training = kwargs.pop('training', True)
  326. self.preprocessor_train_cfg = kwargs.pop('train', None)
  327. self.preprocessor_val_cfg = kwargs.pop('val', None)
  328. def train(self):
  329. self.training = True
  330. return
  331. def eval(self):
  332. self.training = False
  333. return
  334. def __call__(self, results: Dict[str, Any]):
  335. """process the raw input data
  336. Args:
  337. results (dict): Result dict from loading pipeline.
  338. Returns:
  339. Dict[str, Any] | None: the preprocessed data
  340. """
  341. pass