vision_middleware_pipeline.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import math
  3. import os.path as osp
  4. from typing import Any, Dict
  5. import numpy as np
  6. import torch
  7. import torchvision.transforms as transforms
  8. from mmcv.parallel import collate, scatter
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.cv.vision_middleware import VisionMiddlewareModel
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.preprocessors import LoadImage
  15. from modelscope.utils.config import Config
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging helper.
# NOTE(review): not referenced anywhere in the visible code of this module.
logger = get_logger()
  19. @PIPELINES.register_module(
  20. Tasks.image_segmentation,
  21. module_name=Pipelines.vision_middleware_multi_task)
  22. class VisionMiddlewarePipeline(Pipeline):
  23. def __init__(self, model: str, **kwargs):
  24. """
  25. use `model` to create a vision middleware pipeline for prediction
  26. Args:
  27. model: model id on modelscope hub.
  28. """
  29. super().__init__(model=model, **kwargs)
  30. self.model = self.model.cuda()
  31. self.model.eval()
  32. self.transform = transforms.Compose([
  33. transforms.Resize((224, 224)),
  34. transforms.ToTensor(),
  35. transforms.Normalize(
  36. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  37. ])
  38. def preprocess(self, input: Input) -> Dict[str, Any]:
  39. img = LoadImage.convert_to_img(input)
  40. data = self.transform(img)
  41. data = collate([data], samples_per_gpu=1)
  42. if next(self.model.parameters()).is_cuda:
  43. # scatter to specified GPU
  44. data = scatter(data, [next(self.model.parameters()).device])[0]
  45. return data
  46. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  47. with torch.no_grad():
  48. # currently only support one task in pipeline
  49. results = self.model(input, task_name='seg-voc')
  50. return results
  51. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  52. return inputs