mask_generation.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. from collections import defaultdict
  2. from typing import TYPE_CHECKING, Any, Optional, Union, overload
  3. from ..image_utils import load_image
  4. from ..utils import (
  5. add_end_docstrings,
  6. is_torch_available,
  7. logging,
  8. requires_backends,
  9. )
  10. from .base import ChunkPipeline, build_pipeline_init_args
  11. if is_torch_available():
  12. import torch
  13. from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
  14. if TYPE_CHECKING:
  15. from PIL import Image
  16. logger = logging.get_logger(__name__)
  17. @add_end_docstrings(
  18. build_pipeline_init_args(has_image_processor=True),
  19. r"""
  20. points_per_batch (*optional*, int, default to 64):
  21. Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
  22. memory.
  23. output_bboxes_mask (`bool`, *optional*, default to `False`):
  24. Whether or not to output the bounding box predictions.
  25. output_rle_masks (`bool`, *optional*, default to `False`):
  26. Whether or not to output the masks in `RLE` format""",
  27. )
  28. class MaskGenerationPipeline(ChunkPipeline):
  29. """
  30. Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
  31. image, given an image. It is a `ChunkPipeline` because you can separate the points in a mini-batch in order to
  32. avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
  33. same time. Default is `64`.
  34. The pipeline works in 3 steps:
  35. 1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
  36. labels.
  37. For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
  38. function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
  39. `points_per_batch`.
  40. 2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
  41. Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
  42. tensors and models are on the same device.
  43. 3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
  44. are induced:
  45. - image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
  46. resizes them according
  47. to the image size, and transforms there to binary masks.
  48. - image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
  49. `stability_scores`. Also
  50. applies a variety of filters based on non maximum suppression to remove bad masks.
  51. - image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.
  52. Example:
  53. ```python
  54. >>> from transformers import pipeline
  55. >>> generator = pipeline(model="facebook/sam-vit-base", task="mask-generation")
  56. >>> outputs = generator(
  57. ... "http://images.cocodataset.org/val2017/000000039769.jpg",
  58. ... )
  59. >>> outputs = generator(
  60. ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
  61. ... )
  62. ```
  63. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  64. This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  65. `"mask-generation"`.
  66. See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
  67. """
  68. _load_processor = False
  69. _load_image_processor = True
  70. _load_feature_extractor = False
  71. _load_tokenizer = False
  72. def __init__(self, **kwargs):
  73. super().__init__(**kwargs)
  74. requires_backends(self, "vision")
  75. requires_backends(self, "torch")
  76. if self.framework != "pt":
  77. raise ValueError(f"The {self.__class__} is only available in PyTorch.")
  78. self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
  79. def _sanitize_parameters(self, **kwargs):
  80. preprocess_kwargs = {}
  81. postprocess_kwargs = {}
  82. forward_params = {}
  83. # preprocess args
  84. if "points_per_batch" in kwargs:
  85. preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
  86. if "points_per_crop" in kwargs:
  87. preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
  88. if "crops_n_layers" in kwargs:
  89. preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
  90. if "crop_overlap_ratio" in kwargs:
  91. preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
  92. if "crop_n_points_downscale_factor" in kwargs:
  93. preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
  94. if "timeout" in kwargs:
  95. preprocess_kwargs["timeout"] = kwargs["timeout"]
  96. # postprocess args
  97. if "pred_iou_thresh" in kwargs:
  98. forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
  99. if "stability_score_offset" in kwargs:
  100. forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
  101. if "mask_threshold" in kwargs:
  102. forward_params["mask_threshold"] = kwargs["mask_threshold"]
  103. if "stability_score_thresh" in kwargs:
  104. forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
  105. if "max_hole_area" in kwargs:
  106. forward_params["max_hole_area"] = kwargs["max_hole_area"]
  107. if "max_sprinkle_area" in kwargs:
  108. forward_params["max_sprinkle_area"] = kwargs["max_sprinkle_area"]
  109. if "crops_nms_thresh" in kwargs:
  110. postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
  111. if "output_rle_mask" in kwargs:
  112. postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
  113. if "output_bboxes_mask" in kwargs:
  114. postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
  115. return preprocess_kwargs, forward_params, postprocess_kwargs
  116. @overload
  117. def __call__(self, image: Union[str, "Image.Image"], *args: Any, **kwargs: Any) -> dict[str, Any]: ...
  118. @overload
  119. def __call__(
  120. self, image: Union[list[str], list["Image.Image"]], *args: Any, **kwargs: Any
  121. ) -> list[dict[str, Any]]: ...
  122. def __call__(
  123. self, image: Union[str, "Image.Image", list[str], list["Image.Image"]], *args: Any, **kwargs: Any
  124. ) -> Union[dict[str, Any], list[dict[str, Any]]]:
  125. """
  126. Generates binary segmentation masks
  127. Args:
  128. image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
  129. Image or list of images.
  130. mask_threshold (`float`, *optional*, defaults to 0.0):
  131. Threshold to use when turning the predicted masks into binary values.
  132. pred_iou_thresh (`float`, *optional*, defaults to 0.88):
  133. A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
  134. stability_score_thresh (`float`, *optional*, defaults to 0.95):
  135. A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
  136. binarize the model's mask predictions.
  137. stability_score_offset (`int`, *optional*, defaults to 1):
  138. The amount to shift the cutoff when calculated the stability score.
  139. crops_nms_thresh (`float`, *optional*, defaults to 0.7):
  140. The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
  141. crops_n_layers (`int`, *optional*, defaults to 0):
  142. If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
  143. layers to run, where each layer has 2**i_layer number of image crops.
  144. crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
  145. Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
  146. the image length. Later layers with more crops scale down this overlap.
  147. crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
  148. The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
  149. timeout (`float`, *optional*, defaults to None):
  150. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  151. the call may block forever.
  152. Return:
  153. `Dict`: A dictionary with the following keys:
  154. - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
  155. height)` of the original image. Returns a mask filled with zeros if no object is found.
  156. - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of
  157. the "object" described by the label and the mask.
  158. """
  159. num_workers = kwargs.pop("num_workers", None)
  160. batch_size = kwargs.pop("batch_size", None)
  161. return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
  162. def preprocess(
  163. self,
  164. image,
  165. points_per_batch=64,
  166. crops_n_layers: int = 0,
  167. crop_overlap_ratio: float = 512 / 1500,
  168. points_per_crop: int = 32,
  169. crop_n_points_downscale_factor: int = 1,
  170. timeout: Optional[float] = None,
  171. ):
  172. image = load_image(image, timeout=timeout)
  173. target_size = self.image_processor.size.get("longest_edge", self.image_processor.size.get("height"))
  174. crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
  175. image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
  176. )
  177. model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
  178. if self.framework == "pt":
  179. model_inputs = model_inputs.to(self.dtype)
  180. with self.device_placement():
  181. if self.framework == "pt":
  182. inference_context = self.get_inference_context()
  183. with inference_context():
  184. model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
  185. embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
  186. # Handle both SAM (single tensor) and SAM-HQ (tuple) outputs
  187. if isinstance(embeddings, tuple):
  188. image_embeddings, intermediate_embeddings = embeddings
  189. model_inputs["intermediate_embeddings"] = intermediate_embeddings
  190. else:
  191. image_embeddings = embeddings
  192. # TODO: Identifying the model by the type of its returned embeddings is brittle.
  193. # Consider using a more robust method for distinguishing model types here.
  194. model_inputs["image_embeddings"] = image_embeddings
  195. n_points = grid_points.shape[1]
  196. points_per_batch = points_per_batch if points_per_batch is not None else n_points
  197. if points_per_batch <= 0:
  198. raise ValueError(
  199. "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
  200. "To return all points at once, set points_per_batch to None"
  201. )
  202. for i in range(0, n_points, points_per_batch):
  203. batched_points = grid_points[:, i : i + points_per_batch, :, :]
  204. labels = input_labels[:, i : i + points_per_batch]
  205. is_last = i == n_points - points_per_batch
  206. yield {
  207. "input_points": batched_points,
  208. "input_labels": labels,
  209. "input_boxes": crop_boxes,
  210. "is_last": is_last,
  211. **model_inputs,
  212. }
  213. def _forward(
  214. self,
  215. model_inputs,
  216. pred_iou_thresh=0.88,
  217. stability_score_thresh=0.95,
  218. mask_threshold=0,
  219. stability_score_offset=1,
  220. max_hole_area=None,
  221. max_sprinkle_area=None,
  222. ):
  223. input_boxes = model_inputs.pop("input_boxes")
  224. is_last = model_inputs.pop("is_last")
  225. original_sizes = model_inputs.pop("original_sizes").tolist()
  226. reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
  227. model_outputs = self.model(**model_inputs)
  228. # post processing happens here in order to avoid CPU GPU copies of ALL the masks
  229. low_resolution_masks = model_outputs["pred_masks"]
  230. postprocess_kwargs = {}
  231. if max_hole_area is not None:
  232. postprocess_kwargs["max_hole_area"] = max_hole_area
  233. if max_sprinkle_area is not None and max_sprinkle_area > 0:
  234. postprocess_kwargs["max_sprinkle_area"] = max_sprinkle_area
  235. if postprocess_kwargs:
  236. low_resolution_masks = self.image_processor.post_process_masks(
  237. low_resolution_masks,
  238. original_sizes,
  239. mask_threshold=mask_threshold,
  240. reshaped_input_sizes=reshaped_input_sizes,
  241. binarize=False,
  242. **postprocess_kwargs,
  243. )
  244. masks = self.image_processor.post_process_masks(
  245. low_resolution_masks,
  246. original_sizes,
  247. mask_threshold=mask_threshold,
  248. reshaped_input_sizes=reshaped_input_sizes,
  249. binarize=False,
  250. )
  251. iou_scores = model_outputs["iou_scores"]
  252. masks, iou_scores, boxes = self.image_processor.filter_masks(
  253. masks[0],
  254. iou_scores[0],
  255. original_sizes[0],
  256. input_boxes[0],
  257. pred_iou_thresh,
  258. stability_score_thresh,
  259. mask_threshold,
  260. stability_score_offset,
  261. )
  262. return {
  263. "masks": masks,
  264. "is_last": is_last,
  265. "boxes": boxes,
  266. "iou_scores": iou_scores,
  267. }
  268. def postprocess(
  269. self,
  270. model_outputs,
  271. output_rle_mask=False,
  272. output_bboxes_mask=False,
  273. crops_nms_thresh=0.7,
  274. ):
  275. all_scores = []
  276. all_masks = []
  277. all_boxes = []
  278. for model_output in model_outputs:
  279. all_scores.append(model_output.pop("iou_scores"))
  280. all_masks.extend(model_output.pop("masks"))
  281. all_boxes.append(model_output.pop("boxes"))
  282. all_scores = torch.cat(all_scores)
  283. all_boxes = torch.cat(all_boxes)
  284. output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
  285. all_masks, all_scores, all_boxes, crops_nms_thresh
  286. )
  287. extra = defaultdict(list)
  288. for output in model_outputs:
  289. for k, v in output.items():
  290. extra[k].append(v)
  291. optional = {}
  292. if output_rle_mask:
  293. optional["rle_mask"] = rle_mask
  294. if output_bboxes_mask:
  295. optional["bounding_boxes"] = bounding_boxes
  296. return {"masks": output_masks, "scores": iou_scores, **optional, **extra}