| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict
- import cv2
- import einops
- import numpy as np
- import requests
- import torch
- from PIL import Image
- from modelscope.metainfo import Pipelines
- from modelscope.models.cv.anydoor.cldm.ddim_hacked import DDIMSampler
- from modelscope.models.cv.anydoor.datasets.data_utils import (
- box2squre, box_in_box, expand_bbox, expand_image_mask, get_bbox_from_mask,
- pad_to_square, sobel)
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Input, Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors.image import load_image
- from modelscope.utils.constant import Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()
- @PIPELINES.register_module(
- Tasks.image_to_image_generation, module_name=Pipelines.anydoor)
- class AnydoorPipeline(Pipeline):
- r""" AnyDoor Pipeline.
- Examples:
- >>> from modelscope.pipelines import pipeline
- >>> from modelscope.utils.constant import Tasks
- >>> from PIL import Image
- >>> ref_image = 'data/test/images/image_anydoor_fg.png'
- >>> ref_mask = 'data/test/images/image_anydoor_fg_mask.png'
- >>> bg_image = 'data/test/images/image_anydoor_bg.png'
- >>> bg_mask = 'data/test/images/image_anydoor_bg_mask.png'
- >>> anydoor_pipeline = pipeline(Tasks.image_to_image_generation, model='damo/AnyDoor')
- >>> out = anydoor_pipeline((ref_image, ref_mask, bg_image, bg_mask))
- >>> assert isinstance(out['output_img'], Image.Image)
- """
- def __init__(self, model: str, **kwargs):
- """
- use `model` to create a action detection pipeline for prediction
- Args:
- model: model id on modelscope hub.
- """
- super().__init__(model=model, **kwargs)
- model_ckpt = os.path.join(self.model.model_dir,
- self.cfg.model.model_path)
- self.model.load_state_dict(
- self._get_state_dict(model_ckpt, location='cuda'))
- self.ddim_sampler = DDIMSampler(self.model)
- @staticmethod
- def _get_state_dict(ckpt_path, location='cpu'):
- def get_state_dict(d):
- return d.get('state_dict', d)
- _, extension = os.path.splitext(ckpt_path)
- if extension.lower() == '.safetensors':
- import safetensors.torch
- state_dict = safetensors.torch.load_file(
- ckpt_path, device=location)
- else:
- state_dict = get_state_dict(
- torch.load(
- ckpt_path,
- map_location=torch.device(location),
- weights_only=True))
- state_dict = get_state_dict(state_dict)
- print(f'Loaded state_dict from [{ckpt_path}]')
- return state_dict
- def preprocess(self, inputs: Input) -> Dict[str, Any]:
- ref_image, ref_mask, tar_image, tar_mask = inputs
- ref_image = np.asarray(load_image(ref_image).convert('RGB'))
- ref_mask = np.where(
- np.asarray(load_image(ref_mask).convert('L')) > 128, 1,
- 0).astype(np.uint8)
- tar_image = np.asarray(load_image(tar_image).convert('RGB'))
- tar_mask = np.where(
- np.asarray(load_image(tar_mask).convert('L')) > 128, 1,
- 0).astype(np.uint8)
- # ========= Reference ===========
- # ref expand
- ref_box_yyxx = get_bbox_from_mask(ref_mask)
- # ref filter mask
- ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
- masked_ref_image = ref_image * ref_mask_3 + np.ones_like(
- ref_image) * 255 * (1 - ref_mask_3)
- y1, y2, x1, x2 = ref_box_yyxx
- masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
- ref_mask = ref_mask[y1:y2, x1:x2]
- ratio = np.random.randint(11, 15) / 10 # 11,13
- masked_ref_image, ref_mask = expand_image_mask(
- masked_ref_image, ref_mask, ratio=ratio)
- ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
- # to square and resize
- masked_ref_image = pad_to_square(
- masked_ref_image, pad_value=255, random=False)
- masked_ref_image = cv2.resize(
- masked_ref_image.astype(np.uint8), (224, 224)).astype(np.uint8)
- ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value=0, random=False)
- ref_mask_3 = cv2.resize(ref_mask_3.astype(np.uint8),
- (224, 224)).astype(np.uint8)
- ref_mask = ref_mask_3[:, :, 0]
- # collage aug
- masked_ref_image_compose, ref_mask_compose = masked_ref_image, ref_mask
- ref_mask_3 = np.stack(
- [ref_mask_compose, ref_mask_compose, ref_mask_compose], -1)
- ref_image_collage = sobel(masked_ref_image_compose,
- ref_mask_compose / 255)
- # ========= Target ===========
- tar_box_yyxx = get_bbox_from_mask(tar_mask)
- tar_box_yyxx = expand_bbox(
- tar_mask, tar_box_yyxx, ratio=[1.1, 1.2]) # 1.1 1.3
- # crop
- tar_box_yyxx_crop = expand_bbox(
- tar_image, tar_box_yyxx, ratio=[1.3, 3.0])
- tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
- y1, y2, x1, x2 = tar_box_yyxx_crop
- cropped_target_image = tar_image[y1:y2, x1:x2, :]
- cropped_tar_mask = tar_mask[y1:y2, x1:x2]
- tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
- y1, y2, x1, x2 = tar_box_yyxx
- # collage
- ref_image_collage = cv2.resize(
- ref_image_collage.astype(np.uint8), (x2 - x1, y2 - y1))
- ref_mask_compose = cv2.resize(
- ref_mask_compose.astype(np.uint8), (x2 - x1, y2 - y1))
- ref_mask_compose = (ref_mask_compose > 128).astype(np.uint8)
- collage = cropped_target_image.copy()
- collage[y1:y2, x1:x2, :] = ref_image_collage
- collage_mask = cropped_target_image.copy() * 0.0
- collage_mask[y1:y2, x1:x2, :] = 1.0
- collage_mask = np.stack(
- [cropped_tar_mask, cropped_tar_mask, cropped_tar_mask], -1)
- # the size before pad
- H1, W1 = collage.shape[0], collage.shape[1]
- cropped_target_image = pad_to_square(
- cropped_target_image, pad_value=0, random=False).astype(np.uint8)
- collage = pad_to_square(
- collage, pad_value=0, random=False).astype(np.uint8)
- collage_mask = pad_to_square(
- collage_mask, pad_value=0, random=False).astype(np.uint8)
- # the size after pad
- H2, W2 = collage.shape[0], collage.shape[1]
- cropped_target_image = cv2.resize(
- cropped_target_image.astype(np.uint8),
- (512, 512)).astype(np.float32)
- collage = cv2.resize(collage.astype(np.uint8),
- (512, 512)).astype(np.float32)
- collage_mask = (cv2.resize(collage_mask.astype(
- np.uint8), (512, 512)).astype(np.float32) > 0.5).astype(np.float32)
- masked_ref_image = masked_ref_image / 255
- cropped_target_image = cropped_target_image / 127.5 - 1.0
- collage = collage / 127.5 - 1.0
- collage = np.concatenate([collage, collage_mask[:, :, :1]], -1)
- item = dict(
- tar_image=tar_image,
- ref=masked_ref_image.copy(),
- jpg=cropped_target_image.copy(),
- hint=collage.copy(),
- extra_sizes=np.array([H1, W1, H2, W2]),
- tar_box_yyxx_crop=np.array(tar_box_yyxx_crop))
- return item
- def forward(self,
- item: Dict[str, Any],
- num_samples=1,
- strength=1.0,
- ddim_steps=30,
- scale=3.0) -> Dict[str, Any]:
- tar_image = item['tar_image'].cpu().numpy()
- ref = item['ref']
- hint = item['hint']
- num_samples = 1
- control = hint.float().cuda()
- control = torch.stack([control for _ in range(num_samples)], dim=0)
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
- clip_input = ref.float().cuda()
- clip_input = torch.stack([clip_input for _ in range(num_samples)],
- dim=0)
- clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()
- H, W = 512, 512
- cond = {
- 'c_concat': [control],
- 'c_crossattn': [self.model.get_learned_conditioning(clip_input)]
- }
- un_cond = {
- 'c_concat': [control],
- 'c_crossattn': [
- self.model.get_learned_conditioning(
- [torch.zeros((1, 3, 224, 224))] * num_samples)
- ]
- }
- shape = (4, H // 8, W // 8)
- self.model.control_scales = ([strength] * 13)
- samples, _ = self.ddim_sampler.sample(
- ddim_steps,
- num_samples,
- shape,
- cond,
- verbose=False,
- eta=0,
- unconditional_guidance_scale=scale,
- unconditional_conditioning=un_cond)
- x_samples = self.model.decode_first_stage(samples)
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5
- + 127.5).cpu().numpy()
- result = x_samples[0][:, :, ::-1]
- result = np.clip(result, 0, 255)
- pred = x_samples[0]
- pred = np.clip(pred, 0, 255)[1:, :, :]
- sizes = item['extra_sizes'].cpu().numpy()
- tar_box_yyxx_crop = item['tar_box_yyxx_crop'].cpu().numpy()
- return dict(
- pred=pred,
- tar_image=tar_image,
- sizes=sizes,
- tar_box_yyxx_crop=tar_box_yyxx_crop)
- def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- pred = inputs['pred']
- tar_image = inputs['tar_image']
- extra_sizes = inputs['sizes']
- tar_box_yyxx_crop = inputs['tar_box_yyxx_crop']
- H1, W1, H2, W2 = extra_sizes
- y1, y2, x1, x2 = tar_box_yyxx_crop
- pred = cv2.resize(pred, (W2, H2))
- m = 3 # maigin_pixel
- if W1 == H1:
- tar_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[m:-m, m:-m]
- gen_image = torch.from_numpy(tar_image.copy()).permute(2, 0, 1)
- gen_image = gen_image.permute(1, 2, 0).numpy()
- gen_image = Image.fromarray(gen_image, mode='RGB')
- return {OutputKeys.OUTPUT_IMG: gen_image}
- if W1 < W2:
- pad1 = int((W2 - W1) / 2)
- pad2 = W2 - W1 - pad1
- pred = pred[:, pad1:-pad2, :]
- else:
- pad1 = int((H2 - H1) / 2)
- pad2 = H2 - H1 - pad1
- pred = pred[pad1:-pad2, :, :]
- gen_image = tar_image.copy()
- gen_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[m:-m, m:-m]
- gen_image = torch.from_numpy(gen_image).permute(2, 0, 1)
- gen_image = gen_image.permute(1, 2, 0).numpy()
- gen_image = Image.fromarray(gen_image, mode='RGB')
- return {OutputKeys.OUTPUT_IMG: gen_image}
|