crowd_counting_pipeline.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. from typing import Any, Dict
  4. import numpy as np
  5. import torch
  6. import torchvision.transforms as transforms
  7. from PIL import Image
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.cv.crowd_counting import HRNetCrowdCounting
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Input, Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.preprocessors.image import LoadImage
  14. from modelscope.utils.constant import Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. @PIPELINES.register_module(
  18. Tasks.crowd_counting, module_name=Pipelines.crowd_counting)
  19. class CrowdCountingPipeline(Pipeline):
  20. def __init__(self, model: str, **kwargs):
  21. """
  22. model: model id on modelscope hub.
  23. """
  24. assert isinstance(model, str), 'model must be a single str'
  25. super().__init__(model=model, auto_collate=False, **kwargs)
  26. logger.info(f'loading model from dir {model}')
  27. self.infer_model = HRNetCrowdCounting(model).to(self.device)
  28. self.infer_model.eval()
  29. logger.info('load model done')
  30. def resize(self, img):
  31. height = img.size[1]
  32. width = img.size[0]
  33. resize_height = height
  34. resize_width = width
  35. if resize_width >= 2048:
  36. tmp = resize_width
  37. resize_width = 2048
  38. resize_height = (resize_width / tmp) * resize_height
  39. if resize_height >= 2048:
  40. tmp = resize_height
  41. resize_height = 2048
  42. resize_width = (resize_height / tmp) * resize_width
  43. if resize_height <= 416:
  44. tmp = resize_height
  45. resize_height = 416
  46. resize_width = (resize_height / tmp) * resize_width
  47. if resize_width <= 416:
  48. tmp = resize_width
  49. resize_width = 416
  50. resize_height = (resize_width / tmp) * resize_height
  51. # other constraints
  52. if resize_height < resize_width:
  53. if resize_width / resize_height > 2048 / 416: # 1024/416=2.46
  54. resize_width = 2048
  55. resize_height = 416
  56. else:
  57. if resize_height / resize_width > 2048 / 416:
  58. resize_height = 2048
  59. resize_width = 416
  60. resize_height = math.ceil(resize_height / 32) * 32
  61. resize_width = math.ceil(resize_width / 32) * 32
  62. img = transforms.Resize([resize_height, resize_width])(img)
  63. return img
  64. def merge_crops(self, eval_shape, eval_p, pred_m):
  65. for i in range(3):
  66. for j in range(3):
  67. start_h, start_w = math.floor(eval_shape[2] / 4), math.floor(
  68. eval_shape[3] / 4)
  69. valid_h, valid_w = eval_shape[2] // 2, eval_shape[3] // 2
  70. pred_h = math.floor(
  71. 3 * eval_shape[2] / 4) + (eval_shape[2] // 2) * (
  72. i - 1)
  73. pred_w = math.floor(
  74. 3 * eval_shape[3] / 4) + (eval_shape[3] // 2) * (
  75. j - 1)
  76. if i == 0:
  77. valid_h = math.floor(3 * eval_shape[2] / 4)
  78. start_h = 0
  79. pred_h = 0
  80. elif i == 2:
  81. valid_h = math.ceil(3 * eval_shape[2] / 4)
  82. if j == 0:
  83. valid_w = math.floor(3 * eval_shape[3] / 4)
  84. start_w = 0
  85. pred_w = 0
  86. elif j == 2:
  87. valid_w = math.ceil(3 * eval_shape[3] / 4)
  88. pred_m[:, :, pred_h:pred_h + valid_h, pred_w:pred_w
  89. + valid_w] += eval_p[i * 3 + j:i * 3 + j + 1, :,
  90. start_h:start_h + valid_h,
  91. start_w:start_w + valid_w]
  92. return pred_m
  93. def preprocess(self, input: Input) -> Dict[str, Any]:
  94. img = LoadImage.convert_to_img(input)
  95. img = self.resize(img)
  96. img_ori_tensor = transforms.ToTensor()(img)
  97. img_shape = img_ori_tensor.shape
  98. img = transforms.Normalize((0.485, 0.456, 0.406),
  99. (0.229, 0.224, 0.225))(
  100. img_ori_tensor)
  101. patch_height, patch_width = (img_shape[1]) // 2, (img_shape[2]) // 2
  102. imgs = []
  103. for i in range(3):
  104. for j in range(3):
  105. start_h, start_w = (patch_height // 2) * i, (patch_width
  106. // 2) * j
  107. imgs.append(img[:, start_h:start_h + patch_height,
  108. start_w:start_w + patch_width])
  109. imgs = torch.stack(imgs)
  110. eval_img = imgs.to(self.device)
  111. eval_patchs = torch.squeeze(eval_img)
  112. prediction_map = torch.zeros(
  113. (1, 1, img_shape[1] // 2, img_shape[2] // 2)).to(self.device)
  114. result = {
  115. 'img': eval_patchs,
  116. 'map': prediction_map,
  117. }
  118. return result
  119. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  120. counts, img_data = self.perform_inference(input)
  121. return {OutputKeys.SCORES: counts, OutputKeys.OUTPUT_IMG: img_data}
  122. @torch.no_grad()
  123. def perform_inference(self, data):
  124. eval_patchs = data['img']
  125. prediction_map = data['map']
  126. eval_prediction, _, _ = self.infer_model(eval_patchs)
  127. eval_patchs_shape = eval_prediction.shape
  128. prediction_map = self.merge_crops(eval_patchs_shape, eval_prediction,
  129. prediction_map)
  130. return torch.sum(
  131. prediction_map, dim=(
  132. 1, 2,
  133. 3)).data.cpu().numpy(), prediction_map.data.cpu().numpy()[0][0]
  134. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  135. return inputs