| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os.path as osp
- from typing import Any, Dict
- import cv2
- import numpy as np
- import PIL
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torchvision import transforms
- from modelscope.metainfo import Pipelines
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines import pipeline
- from modelscope.pipelines.base import Input, Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import LoadImage
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()
- @PIPELINES.register_module(
- Tasks.image_classification, module_name=Pipelines.content_check)
- class ContentCheckPipeline(Pipeline):
- def __init__(self, model: str, **kwargs):
- """
- use `model` to create a content check pipeline for prediction
- Args:
- model: model id on modelscope hub.
- Example:
- ContentCheckPipeline can judge whether the picture is pornographic
- ```python
- >>> from modelscope.pipelines import pipeline
- >>> cc_func = pipeline('image_classification', 'damo/cv_resnet50_image-classification_cc')
- >>> cc_func("https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/content_check.jpg")
- {'scores': [0.2789826989173889], 'labels': 'pornographic'}
- ```
- """
- # content check model
- super().__init__(model=model, **kwargs)
- self.test_transforms = transforms.Compose([
- transforms.Resize(224),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
- logger.info('content check model loaded!')
- def preprocess(self, input: Input) -> Dict[str, Any]:
- img = LoadImage.convert_to_img(input)
- img = self.test_transforms(img).float()
- result = {}
- result['img'] = img
- return result
- def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
- img = input['img'].unsqueeze(0)
- result = self.model(img)
- score = [1 - F.softmax(result[:, :5])[0][-1].tolist()]
- if score[0] < 0.5:
- label = 'pornographic'
- else:
- label = 'normal'
- return {OutputKeys.SCORES: score, OutputKeys.LABELS: label}
- def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- return inputs
|