image_segmentation.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from typing import Any, Union, overload
  2. import numpy as np
  3. from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
  4. from .base import Pipeline, build_pipeline_init_args
  5. if is_vision_available():
  6. from PIL import Image
  7. from ..image_utils import load_image
  8. if is_torch_available():
  9. from ..models.auto.modeling_auto import (
  10. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
  11. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
  12. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
  13. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
  14. )
# Module-level logger obtained from the package's `logging` utilities (imported from `..utils`).
# It is not referenced by the pipeline code in this section.
logger = logging.get_logger(__name__)
  16. @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
  17. class ImageSegmentationPipeline(Pipeline):
  18. """
  19. Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
  20. their classes.
  21. Example:
  22. ```python
  23. >>> from transformers import pipeline
  24. >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
  25. >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
  26. >>> len(segments)
  27. 2
  28. >>> segments[0]["label"]
  29. 'bird'
  30. >>> segments[1]["label"]
  31. 'bird'
  32. >>> type(segments[0]["mask"]) # This is a black and white mask showing where is the bird on the original image.
  33. <class 'PIL.Image.Image'>
  34. >>> segments[0]["mask"].size
  35. (768, 512)
  36. ```
  37. This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  38. `"image-segmentation"`.
  39. See the list of available models on
  40. [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
  41. """
  42. _load_processor = False
  43. _load_image_processor = True
  44. _load_feature_extractor = False
  45. _load_tokenizer = None # Oneformer uses it but no-one else does
  46. def __init__(self, *args, **kwargs):
  47. super().__init__(*args, **kwargs)
  48. if self.framework == "tf":
  49. raise ValueError(f"The {self.__class__} is only available in PyTorch.")
  50. requires_backends(self, "vision")
  51. mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
  52. mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
  53. mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
  54. mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
  55. self.check_model_type(mapping)
  56. def _sanitize_parameters(self, **kwargs):
  57. preprocess_kwargs = {}
  58. postprocess_kwargs = {}
  59. if "subtask" in kwargs:
  60. postprocess_kwargs["subtask"] = kwargs["subtask"]
  61. preprocess_kwargs["subtask"] = kwargs["subtask"]
  62. if "threshold" in kwargs:
  63. postprocess_kwargs["threshold"] = kwargs["threshold"]
  64. if "mask_threshold" in kwargs:
  65. postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
  66. if "overlap_mask_area_threshold" in kwargs:
  67. postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
  68. if "timeout" in kwargs:
  69. preprocess_kwargs["timeout"] = kwargs["timeout"]
  70. return preprocess_kwargs, {}, postprocess_kwargs
  71. @overload
  72. def __call__(self, inputs: Union[str, "Image.Image"], **kwargs: Any) -> list[dict[str, Any]]: ...
  73. @overload
  74. def __call__(self, inputs: Union[list[str], list["Image.Image"]], **kwargs: Any) -> list[list[dict[str, Any]]]: ...
  75. def __call__(
  76. self, inputs: Union[str, "Image.Image", list[str], list["Image.Image"]], **kwargs: Any
  77. ) -> Union[list[dict[str, Any]], list[list[dict[str, Any]]]]:
  78. """
  79. Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
  80. Args:
  81. inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  82. The pipeline handles three types of images:
  83. - A string containing an HTTP(S) link pointing to an image
  84. - A string containing a local path to an image
  85. - An image loaded in PIL directly
  86. The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
  87. same format: all as HTTP(S) links, all as local paths, or all as PIL images.
  88. subtask (`str`, *optional*):
  89. Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
  90. capabilities. If not set, the pipeline will attempt tp resolve in the following order:
  91. `panoptic`, `instance`, `semantic`.
  92. threshold (`float`, *optional*, defaults to 0.9):
  93. Probability threshold to filter out predicted masks.
  94. mask_threshold (`float`, *optional*, defaults to 0.5):
  95. Threshold to use when turning the predicted masks into binary values.
  96. overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
  97. Mask overlap threshold to eliminate small, disconnected segments.
  98. timeout (`float`, *optional*, defaults to None):
  99. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  100. the call may block forever.
  101. Return:
  102. If the input is a single image, will return a list of dictionaries, if the input is a list of several images,
  103. will return a list of list of dictionaries corresponding to each image.
  104. The dictionaries contain the mask, label and score (where applicable) of each detected object and contains
  105. the following keys:
  106. - **label** (`str`) -- The class label identified by the model.
  107. - **mask** (`PIL.Image`) -- A binary mask of the detected object as a Pil Image of shape (width, height) of
  108. the original image. Returns a mask filled with zeros if no object is found.
  109. - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
  110. "object" described by the label and the mask.
  111. """
  112. # After deprecation of this is completed, remove the default `None` value for `images`
  113. if "images" in kwargs:
  114. inputs = kwargs.pop("images")
  115. if inputs is None:
  116. raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
  117. return super().__call__(inputs, **kwargs)
  118. def preprocess(self, image, subtask=None, timeout=None):
  119. image = load_image(image, timeout=timeout)
  120. target_size = [(image.height, image.width)]
  121. if self.model.config.__class__.__name__ == "OneFormerConfig":
  122. if subtask is None:
  123. kwargs = {}
  124. else:
  125. kwargs = {"task_inputs": [subtask]}
  126. inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
  127. if self.framework == "pt":
  128. inputs = inputs.to(self.dtype)
  129. inputs["task_inputs"] = self.tokenizer(
  130. inputs["task_inputs"],
  131. padding="max_length",
  132. max_length=self.model.config.task_seq_len,
  133. return_tensors=self.framework,
  134. )["input_ids"]
  135. else:
  136. inputs = self.image_processor(images=[image], return_tensors="pt")
  137. if self.framework == "pt":
  138. inputs = inputs.to(self.dtype)
  139. inputs["target_size"] = target_size
  140. return inputs
  141. def _forward(self, model_inputs):
  142. target_size = model_inputs.pop("target_size")
  143. model_outputs = self.model(**model_inputs)
  144. model_outputs["target_size"] = target_size
  145. return model_outputs
  146. def postprocess(
  147. self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
  148. ):
  149. fn = None
  150. if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
  151. fn = self.image_processor.post_process_panoptic_segmentation
  152. elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
  153. fn = self.image_processor.post_process_instance_segmentation
  154. if fn is not None:
  155. outputs = fn(
  156. model_outputs,
  157. threshold=threshold,
  158. mask_threshold=mask_threshold,
  159. overlap_mask_area_threshold=overlap_mask_area_threshold,
  160. target_sizes=model_outputs["target_size"],
  161. )[0]
  162. annotation = []
  163. segmentation = outputs["segmentation"]
  164. for segment in outputs["segments_info"]:
  165. mask = (segmentation == segment["id"]) * 255
  166. mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
  167. label = self.model.config.id2label[segment["label_id"]]
  168. score = segment["score"]
  169. annotation.append({"score": score, "label": label, "mask": mask})
  170. elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
  171. outputs = self.image_processor.post_process_semantic_segmentation(
  172. model_outputs, target_sizes=model_outputs["target_size"]
  173. )[0]
  174. annotation = []
  175. segmentation = outputs.numpy()
  176. labels = np.unique(segmentation)
  177. for label in labels:
  178. mask = (segmentation == label) * 255
  179. mask = Image.fromarray(mask.astype(np.uint8), mode="L")
  180. label = self.model.config.id2label[label]
  181. annotation.append({"score": None, "label": label, "mask": mask})
  182. else:
  183. raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
  184. return annotation