vop_retrieval_pipeline.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import gzip
  3. import math
  4. import os
  5. import os.path as osp
  6. import pickle
  7. import random
  8. from collections import defaultdict, deque
  9. from typing import Any, Dict
  10. import numpy as np
  11. import torch
  12. from tqdm import tqdm
  13. from modelscope.metainfo import Pipelines
  14. from modelscope.models import Model
  15. from modelscope.models.cv.vop_retrieval import (LengthAdaptiveTokenizer, VoP,
  16. init_transform_dict, load_data,
  17. load_frames_from_video)
  18. from modelscope.outputs import OutputKeys
  19. from modelscope.pipelines.base import Input, Pipeline
  20. from modelscope.pipelines.builder import PIPELINES
  21. from modelscope.preprocessors import load_image
  22. from modelscope.utils.config import Config
  23. from modelscope.utils.constant import ModelFile, Tasks
  24. from modelscope.utils.logger import get_logger
  25. logger = get_logger()
  26. @PIPELINES.register_module(
  27. Tasks.vop_retrieval, module_name=Pipelines.vop_retrieval)
  28. class VopRetrievalPipeline(Pipeline):
  29. def __init__(self, model: str, **kwargs):
  30. """
  31. use `model` to create a vop pipeline for retrieval
  32. Args:
  33. model: model id on modelscope hub.
  34. """
  35. super().__init__(model=model, **kwargs)
  36. # [from pretrain] load model
  37. self.model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop').to(
  38. self.device)
  39. logger.info('load model done')
  40. # others: load transform
  41. self.local_pth = model
  42. self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
  43. self.img_transform = init_transform_dict(
  44. self.cfg.hyperparam.input_res)['clip_test']
  45. logger.info('load transform done')
  46. # others: load tokenizer
  47. bpe_path = gzip.open(osp.join(
  48. model,
  49. 'bpe_simple_vocab_16e6.txt.gz')).read().decode('utf-8').split('\n')
  50. self.tokenizer = LengthAdaptiveTokenizer(self.cfg.hyperparam, bpe_path)
  51. logger.info('load tokenizer done')
  52. # others: load dataset
  53. self.database = load_data(
  54. osp.join(model, 'VoP_msrvtt9k_features.pkl'), self.device)
  55. logger.info('load database done')
  56. def preprocess(self, input: Input) -> Dict[str, Any]:
  57. if isinstance(input, str):
  58. if '.mp4' in input:
  59. query = []
  60. for video_path in [input]:
  61. video_path = osp.join(self.local_pth, video_path)
  62. imgs, idxs = load_frames_from_video(
  63. video_path, self.cfg.hyperparam.num_frames,
  64. self.cfg.hyperparam.video_sample_type)
  65. imgs = self.img_transform(imgs)
  66. query.append(imgs)
  67. query = torch.stack(
  68. query, dim=0).to(
  69. self.device, non_blocking=True)
  70. mode = 'v2t'
  71. else:
  72. query = self.tokenizer(
  73. input, return_tensors='pt', padding=True, truncation=True)
  74. if isinstance(query, torch.Tensor):
  75. query = query.to(self.device, non_blocking=True)
  76. else:
  77. query = {
  78. key: val.to(self.device, non_blocking=True)
  79. for key, val in query.items()
  80. }
  81. mode = 't2v'
  82. else:
  83. raise TypeError(f'input should be a str,'
  84. f' but got {type(input)}')
  85. result = {'input_data': query, 'mode': mode}
  86. return result
  87. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  88. text_embeds, vid_embeds_pooled, vid_ids, texts = self.database
  89. with torch.no_grad():
  90. if input['mode'] == 't2v':
  91. query_feats = self.model.get_text_features(input['input_data'])
  92. score = query_feats @ vid_embeds_pooled.T
  93. retrieval_idxs = torch.topk(
  94. score, k=self.cfg.hyperparam.topk,
  95. dim=-1)[1].cpu().numpy()
  96. res = np.array(vid_ids)[retrieval_idxs]
  97. elif input['mode'] == 'v2t':
  98. query_feats = self.model.get_video_features(
  99. input['input_data'])
  100. score = query_feats @ text_embeds.T
  101. retrieval_idxs = torch.topk(
  102. score, k=self.cfg.hyperparam.topk,
  103. dim=-1)[1].cpu().numpy()
  104. res = np.array(texts)[retrieval_idxs]
  105. results = {'output_data': res, 'mode': input['mode']}
  106. return results
  107. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  108. return inputs