face_recognition_pipeline.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import PIL
  7. import torch
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.cv.face_recognition.align_face import align_face
  10. from modelscope.models.cv.face_recognition.torchkit.backbone import get_model
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines import pipeline
  13. from modelscope.pipelines.base import Input, Pipeline
  14. from modelscope.pipelines.builder import PIPELINES
  15. from modelscope.pipelines.cv.face_processing_base_pipeline import \
  16. FaceProcessingBasePipeline
  17. from modelscope.preprocessors import LoadImage
  18. from modelscope.utils.constant import ModelFile, Tasks
  19. from modelscope.utils.logger import get_logger
  20. logger = get_logger()
  21. @PIPELINES.register_module(
  22. Tasks.face_recognition, module_name=Pipelines.face_recognition)
  23. class FaceRecognitionPipeline(FaceProcessingBasePipeline):
  24. def __init__(self, model: str, use_det=True, **kwargs):
  25. """
  26. use `model` to create a face recognition pipeline for prediction
  27. Args:
  28. model: model id on modelscope hub.
  29. """
  30. # face recong model
  31. super().__init__(model=model, use_det=use_det, **kwargs)
  32. device = torch.device(
  33. f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
  34. self.device = device
  35. face_model = get_model('IR_101')([112, 112])
  36. face_model.load_state_dict(
  37. torch.load(
  38. osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE),
  39. map_location=device,
  40. weights_only=True))
  41. face_model = face_model.to(device)
  42. face_model.eval()
  43. self.face_model = face_model
  44. logger.info('face recognition model loaded!')
  45. def preprocess(self, input: Input) -> Dict[str, Any]:
  46. result = super().preprocess(input)
  47. align_img = result['img']
  48. face_img = align_img[:, :, ::-1] # to rgb
  49. face_img = np.transpose(face_img, axes=(2, 0, 1))
  50. face_img = (face_img / 255. - 0.5) / 0.5
  51. face_img = face_img.astype(np.float32)
  52. result['img'] = face_img
  53. return result
  54. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  55. assert input['img'] is not None
  56. img = input['img'].unsqueeze(0)
  57. emb = self.face_model(img).detach().cpu().numpy()
  58. emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm
  59. return {OutputKeys.IMG_EMBEDDING: emb}
  60. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  61. return inputs