content_check_pipeline.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import PIL
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import nn
  10. from torchvision import transforms
  11. from modelscope.metainfo import Pipelines
  12. from modelscope.outputs import OutputKeys
  13. from modelscope.pipelines import pipeline
  14. from modelscope.pipelines.base import Input, Pipeline
  15. from modelscope.pipelines.builder import PIPELINES
  16. from modelscope.preprocessors import LoadImage
  17. from modelscope.utils.constant import ModelFile, Tasks
  18. from modelscope.utils.logger import get_logger
# Module-level logger from modelscope's logging utilities; used by the
# pipeline below to report when the content check model is ready.
logger = get_logger()
  20. @PIPELINES.register_module(
  21. Tasks.image_classification, module_name=Pipelines.content_check)
  22. class ContentCheckPipeline(Pipeline):
  23. def __init__(self, model: str, **kwargs):
  24. """
  25. use `model` to create a content check pipeline for prediction
  26. Args:
  27. model: model id on modelscope hub.
  28. Example:
  29. ContentCheckPipeline can judge whether the picture is pornographic
  30. ```python
  31. >>> from modelscope.pipelines import pipeline
  32. >>> cc_func = pipeline('image_classification', 'damo/cv_resnet50_image-classification_cc')
  33. >>> cc_func("https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/content_check.jpg")
  34. {'scores': [0.2789826989173889], 'labels': 'pornographic'}
  35. ```
  36. """
  37. # content check model
  38. super().__init__(model=model, **kwargs)
  39. self.test_transforms = transforms.Compose([
  40. transforms.Resize(224),
  41. transforms.CenterCrop(224),
  42. transforms.ToTensor(),
  43. transforms.Normalize(
  44. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  45. ])
  46. logger.info('content check model loaded!')
  47. def preprocess(self, input: Input) -> Dict[str, Any]:
  48. img = LoadImage.convert_to_img(input)
  49. img = self.test_transforms(img).float()
  50. result = {}
  51. result['img'] = img
  52. return result
  53. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  54. img = input['img'].unsqueeze(0)
  55. result = self.model(img)
  56. score = [1 - F.softmax(result[:, :5])[0][-1].tolist()]
  57. if score[0] < 0.5:
  58. label = 'pornographic'
  59. else:
  60. label = 'normal'
  61. return {OutputKeys.SCORES: score, OutputKeys.LABELS: label}
  62. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  63. return inputs