image_to_image.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, overload
  15. import numpy as np
  16. from ..utils import (
  17. add_end_docstrings,
  18. is_torch_available,
  19. is_vision_available,
  20. logging,
  21. requires_backends,
  22. )
  23. from .base import Pipeline, build_pipeline_init_args
  24. if is_vision_available():
  25. from PIL import Image
  26. from ..image_utils import load_image
  27. if is_torch_available():
  28. from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES
# Module-level logger for this pipeline module, obtained through the package's `logging` utilities.
logger = logging.get_logger(__name__)
  30. @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
  31. class ImageToImagePipeline(Pipeline):
  32. """
  33. Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous
  34. image input.
  35. Example:
  36. ```python
  37. >>> from PIL import Image
  38. >>> import requests
  39. >>> from transformers import pipeline
  40. >>> upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
  41. >>> img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
  42. >>> img = img.resize((64, 64))
  43. >>> upscaled_img = upscaler(img)
  44. >>> img.size
  45. (64, 64)
  46. >>> upscaled_img.size
  47. (144, 144)
  48. ```
  49. This image to image pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  50. `"image-to-image"`.
  51. See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=image-to-image).
  52. """
  53. _load_processor = False
  54. _load_image_processor = True
  55. _load_feature_extractor = False
  56. _load_tokenizer = False
  57. def __init__(self, *args, **kwargs):
  58. super().__init__(*args, **kwargs)
  59. requires_backends(self, "vision")
  60. self.check_model_type(MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
  61. def _sanitize_parameters(self, **kwargs):
  62. preprocess_params = {}
  63. postprocess_params = {}
  64. forward_params = {}
  65. if "timeout" in kwargs:
  66. preprocess_params["timeout"] = kwargs["timeout"]
  67. if "head_mask" in kwargs:
  68. forward_params["head_mask"] = kwargs["head_mask"]
  69. return preprocess_params, forward_params, postprocess_params
  70. @overload
  71. def __call__(self, images: Union[str, "Image.Image"], **kwargs: Any) -> "Image.Image": ...
  72. @overload
  73. def __call__(self, images: Union[list[str], list["Image.Image"]], **kwargs: Any) -> list["Image.Image"]: ...
  74. def __call__(
  75. self, images: Union[str, list[str], "Image.Image", list["Image.Image"]], **kwargs: Any
  76. ) -> Union["Image.Image", list["Image.Image"]]:
  77. """
  78. Transform the image(s) passed as inputs.
  79. Args:
  80. images (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  81. The pipeline handles three types of images:
  82. - A string containing a http link pointing to an image
  83. - A string containing a local path to an image
  84. - An image loaded in PIL directly
  85. The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
  86. Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
  87. images.
  88. timeout (`float`, *optional*, defaults to None):
  89. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
  90. the call may block forever.
  91. Return:
  92. An image (Image.Image) or a list of images (list["Image.Image"]) containing result(s). If the input is a
  93. single image, the return will be also a single image, if the input is a list of several images, it will
  94. return a list of transformed images.
  95. """
  96. return super().__call__(images, **kwargs)
  97. def _forward(self, model_inputs):
  98. model_outputs = self.model(**model_inputs)
  99. return model_outputs
  100. def preprocess(self, image, timeout=None):
  101. image = load_image(image, timeout=timeout)
  102. inputs = self.image_processor(images=[image], return_tensors="pt")
  103. if self.framework == "pt":
  104. inputs = inputs.to(self.dtype)
  105. return inputs
  106. def postprocess(self, model_outputs):
  107. images = []
  108. if "reconstruction" in model_outputs:
  109. outputs = model_outputs.reconstruction
  110. for output in outputs:
  111. output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  112. output = np.moveaxis(output, source=0, destination=-1)
  113. output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
  114. images.append(Image.fromarray(output))
  115. return images if len(images) > 1 else images[0]