rleg.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. """ Generative Multimodal Model Wrapper."""
  3. from typing import Any, Dict
  4. import torch
  5. from torchvision import transforms as T
  6. from modelscope.metainfo import Models
  7. from modelscope.models.base import TorchModel
  8. from modelscope.models.builder import MODELS
  9. from modelscope.models.multi_modal.rleg.model import RLEGModel
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.preprocessors import LoadImage
  12. from modelscope.utils.constant import ModelFile, Tasks
  13. from modelscope.utils.logger import get_logger
  14. logger = get_logger()  # module-level logger; __init__ uses it to report the GPU/CPU choice
  15. __all__ = ['RLEGForMultiModalEmbedding']  # only the wrapper class is exported by a wildcard import
  16. @MODELS.register_module(
  17. Tasks.generative_multi_modal_embedding, module_name=Models.rleg)
  18. class RLEGForMultiModalEmbedding(TorchModel):
  19. """ Generative multi-modal model for multi-modal embedding.
  20. The model is trained by representation learning with embedding generation.
  21. Inputs could be image or text or both of them.
  22. Outputs could be features of input image or text,
  23. """
  24. def __init__(self, model_dir, device_id=0, *args, **kwargs):
  25. super().__init__(
  26. model_dir=model_dir, device_id=device_id, *args, **kwargs)
  27. self.model = RLEGModel(model_dir=model_dir)
  28. pretrained_params = torch.load('{}/{}'.format(
  29. model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
  30. self.model.load_state_dict(pretrained_params)
  31. self.model.eval()
  32. self.device_id = device_id
  33. if self.device_id >= 0 and torch.cuda.is_available():
  34. self.model.to('cuda:{}'.format(self.device_id))
  35. logger.info('Use GPU: {}'.format(self.device_id))
  36. else:
  37. self.device_id = -1
  38. logger.info('Use CPU for inference')
  39. self.img_preprocessor = T.Compose([
  40. T.Resize((224, 224)),
  41. T.ToTensor(),
  42. T.Normalize((0.48145466, 0.4578275, 0.40821073),
  43. (0.26862954, 0.26130258, 0.27577711))
  44. ])
  45. def parse_image(self, input_img):
  46. if input_img is None:
  47. return None
  48. input_img = LoadImage.convert_to_img(input_img)
  49. img_tensor = self.img_preprocessor(input_img)[None, ...]
  50. if self.device_id >= 0:
  51. img_tensor = img_tensor.to('cuda:{}'.format(self.device_id))
  52. return img_tensor
  53. def parse_text(self, text_str):
  54. if text_str is None or len(text_str) == 0:
  55. return None
  56. if isinstance(text_str, str):
  57. text_ids_tensor = self.model.tokenize(text_str)
  58. else:
  59. raise TypeError(f'text should be str, but got {type(text_str)}')
  60. if self.device_id >= 0:
  61. text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
  62. self.device_id))
  63. return text_ids_tensor.view(1, -1)
  64. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  65. image_input = input.get('image', input.get('img', None))
  66. text_input = input.get('text', input.get('txt', None))
  67. image = self.parse_image(image_input)
  68. text = self.parse_text(text_input)
  69. out = self.model(image, text)
  70. output = {
  71. OutputKeys.IMG_EMBEDDING: out.get('image_feature', None),
  72. OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None),
  73. OutputKeys.CAPTION: out.get('caption', None)
  74. }
  75. return output