vidt_pipeline.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. from typing import Any, Dict
  3. import torch
  4. import torchvision.transforms as transforms
  5. from torch import nn
  6. from modelscope.metainfo import Pipelines
  7. from modelscope.pipelines.base import Input, Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.preprocessors import LoadImage
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.logger import get_logger
  12. logger = get_logger()
  13. @PIPELINES.register_module(
  14. Tasks.image_object_detection, module_name=Pipelines.vidt)
  15. class VidtPipeline(Pipeline):
  16. def __init__(self, model: str, **kwargs):
  17. """
  18. use `model` to create a vidt pipeline for prediction
  19. Args:
  20. model: model id on modelscope hub.
  21. Example:
  22. >>> from modelscope.pipelines import pipeline
  23. >>> vidt_pipeline = pipeline('image-object-detection', 'damo/ViDT-logo-detection')
  24. >>> result = vidt_pipeline(
  25. 'data/test/images/vidt_test1.png')
  26. >>> print(f'Output: {result}.')
  27. """
  28. super().__init__(model=model, **kwargs)
  29. self.model.eval()
  30. self.transform = transforms.Compose([
  31. transforms.Resize([640, 640]),
  32. transforms.ToTensor(),
  33. transforms.Normalize(
  34. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  35. ])
  36. self.postprocessors = PostProcess()
  37. self.label_dic = {0: 'negative', 1: 'positive'}
  38. def preprocess(self, inputs: Input, **preprocess_params):
  39. img = LoadImage.convert_to_img(inputs)
  40. ori_size = [img.size[1], img.size[0]]
  41. image = self.transform(img)
  42. tensor_list = [image]
  43. orig_target_sizes = [ori_size]
  44. orig_target_sizes = torch.tensor(orig_target_sizes).to(self.device)
  45. samples = nested_tensor_from_tensor_list(tensor_list)
  46. samples = samples.to(self.device)
  47. res = {}
  48. res['tensors'] = samples.tensors
  49. res['mask'] = samples.mask
  50. res['orig_target_sizes'] = orig_target_sizes
  51. return res
  52. def forward(self, inputs: Dict[str, Any], **forward_params):
  53. tensors = inputs['tensors']
  54. mask = inputs['mask']
  55. orig_target_sizes = inputs['orig_target_sizes']
  56. with torch.no_grad():
  57. out_pred_logits, out_pred_boxes = self.model(tensors, mask)
  58. res = {}
  59. res['out_pred_logits'] = out_pred_logits
  60. res['out_pred_boxes'] = out_pred_boxes
  61. res['orig_target_sizes'] = orig_target_sizes
  62. return res
  63. def postprocess(self, inputs: Dict[str, Any], **post_params):
  64. results = self.postprocessors(inputs['out_pred_logits'],
  65. inputs['out_pred_boxes'],
  66. inputs['orig_target_sizes'])
  67. batch_predictions = get_predictions(results)[0] # 仅支持单张图推理
  68. scores = []
  69. labels = []
  70. boxes = []
  71. for sub_pre in batch_predictions:
  72. scores.append(sub_pre[0])
  73. labels.append(self.label_dic[sub_pre[1]])
  74. boxes.append(sub_pre[2]) # [xmin, ymin, xmax, ymax]
  75. outputs = {}
  76. outputs['scores'] = scores
  77. outputs['labels'] = labels
  78. outputs['boxes'] = boxes
  79. return outputs
  80. def nested_tensor_from_tensor_list(tensor_list):
  81. # TODO make it support different-sized images
  82. max_size = _max_by_axis([list(img.shape) for img in tensor_list])
  83. batch_shape = [len(tensor_list)] + max_size
  84. b, c, h, w = batch_shape
  85. dtype = tensor_list[0].dtype
  86. device = tensor_list[0].device
  87. tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
  88. mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
  89. for img, pad_img, m in zip(tensor_list, tensor, mask):
  90. pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
  91. m[:img.shape[1], :img.shape[2]] = False
  92. return NestedTensor(tensor, mask)
  93. def _max_by_axis(the_list):
  94. # type: (List[List[int]]) -> List[int]
  95. maxes = the_list[0]
  96. for sublist in the_list[1:]:
  97. for index, item in enumerate(sublist):
  98. maxes[index] = max(maxes[index], item)
  99. return maxes
  100. class NestedTensor(object):
  101. def __init__(self, tensors, mask):
  102. self.tensors = tensors
  103. self.mask = mask
  104. def to(self, device):
  105. # type: (Device) -> NestedTensor # noqa
  106. cast_tensor = self.tensors.to(device)
  107. mask = self.mask
  108. if mask is not None:
  109. assert mask is not None
  110. cast_mask = mask.to(device)
  111. else:
  112. cast_mask = None
  113. return NestedTensor(cast_tensor, cast_mask)
  114. def decompose(self):
  115. return self.tensors, self.mask
  116. def __repr__(self):
  117. return str(self.tensors)
  118. def box_cxcywh_to_xyxy(x):
  119. x_c, y_c, w, h = x.unbind(-1)
  120. b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
  121. return torch.stack(b, dim=-1)
  122. # process post_results
  123. def get_predictions(post_results, bbox_thu=0.40):
  124. batch_final_res = []
  125. for per_img_res in post_results:
  126. per_img_final_res = []
  127. for i in range(len(per_img_res['scores'])):
  128. score = float(per_img_res['scores'][i].cpu())
  129. label = int(per_img_res['labels'][i].cpu())
  130. bbox = []
  131. for it in per_img_res['boxes'][i].cpu():
  132. bbox.append(int(it))
  133. if score >= bbox_thu:
  134. per_img_final_res.append([score, label, bbox])
  135. batch_final_res.append(per_img_final_res)
  136. return batch_final_res
  137. class PostProcess(nn.Module):
  138. """ This module converts the model's output into the format expected by the coco api"""
  139. def __init__(self, processor_dct=None):
  140. super().__init__()
  141. # For instance segmentation using UQR module
  142. self.processor_dct = processor_dct
  143. @torch.no_grad()
  144. def forward(self, out_logits, out_bbox, target_sizes):
  145. """ Perform the computation
  146. Parameters:
  147. out_logits: raw logits outputs of the model
  148. out_bbox: raw bbox outputs of the model
  149. target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
  150. For evaluation, this must be the original image size (before any data augmentation)
  151. For visualization, this should be the image size after data augment, but before padding
  152. """
  153. assert len(out_logits) == len(target_sizes)
  154. assert target_sizes.shape[1] == 2
  155. prob = out_logits.sigmoid()
  156. topk_values, topk_indexes = torch.topk(
  157. prob.view(out_logits.shape[0], -1), 100, dim=1)
  158. scores = topk_values
  159. topk_boxes = topk_indexes // out_logits.shape[2]
  160. labels = topk_indexes % out_logits.shape[2]
  161. boxes = box_cxcywh_to_xyxy(out_bbox)
  162. boxes = torch.gather(boxes, 1,
  163. topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
  164. # and from relative [0, 1] to absolute [0, height] coordinates
  165. img_h, img_w = target_sizes.unbind(1)
  166. scale_fct = torch.stack([img_w, img_h, img_w, img_h],
  167. dim=1).to(torch.float32)
  168. boxes = boxes * scale_fct[:, None, :]
  169. results = [{
  170. 'scores': s,
  171. 'labels': l,
  172. 'boxes': b
  173. } for s, l, b in zip(scores, labels, boxes)]
  174. return results