| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
- import os
- import torch
- import torch.nn as nn
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.config import Config
- from modelscope.utils.constant import ModelFile, Tasks
- from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2,
- V2QRankerStage1, V2QRankerStage2)
- from .swin_transformer import SwinTransformerV2_1D
@MODELS.register_module(
    Tasks.video_temporal_grounding, module_name=Models.soonet)
class SOONet(TorchModel):
    """
    The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding
    in Long Videos'. The model is dynamically initialized with the following parts:
        - video_encoder: a SwinTransformerV2_1D applied to the video features.
        - q2v_stage1: calculate qv_ctx_score.
        - v2q_stage1: calculate vq_ctx_score.
        - q2v_stage2: calculate qv_ctn_score (only built when stage 2 is enabled).
        - v2q_stage2: calculate vq_ctn_score (only built when stage 2 is enabled).
        - regressor: predict the offset of bounding box for each candidate anchor.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Initialize SOONet Model

        Reads the `hyperparams` section of the configuration file found in
        `model_dir`, builds the sub-modules and loads the trained weights.

        Args:
            model_dir: model id or path. It must contain the configuration
                file (`ModelFile.CONFIGURATION`) and the checkpoint
                'SOONet_MAD_VIT-B-32_4Scale_10C.pth'.
        """
        # NOTE(review): model_dir / args / kwargs are not forwarded to the
        # base class -- confirm that TorchModel does not rely on them.
        super().__init__()
        config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
        self.config = Config.from_file(config_path).hyperparams
        nscales = self.config.nscales
        hidden_dim = self.config.hidden_dim
        snippet_length = self.config.snippet_length
        self.enable_stage2 = self.config.enable_stage2
        self.stage2_topk = self.config.stage2_topk
        self.nscales = nscales

        # One encoder stage per scale; all stages share the same settings.
        self.video_encoder = SwinTransformerV2_1D(
            patch_size=snippet_length,
            in_chans=hidden_dim,
            embed_dim=hidden_dim,
            depths=[2] * nscales,
            num_heads=[8] * nscales,
            window_size=[64] * nscales,
            mlp_ratio=2.,
            qkv_bias=True,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            patch_norm=True,
            use_checkpoint=False,
            pretrained_window_sizes=[0] * nscales)

        self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim)
        self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim)
        if self.enable_stage2:
            self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim,
                                              snippet_length)
            self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim)
        self.regressor = BboxRegressor(hidden_dim, self.enable_stage2)

        # Load trained weights
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code; only point model_dir at checkpoints from a trusted source.
        model_path = os.path.join(model_dir,
                                  'SOONet_MAD_VIT-B-32_4Scale_10C.pth')
        state_dict = torch.load(model_path, map_location='cpu')['model']
        # strict=True: every parameter must be present in the checkpoint,
        # and the checkpoint must not contain unexpected keys.
        self.load_state_dict(state_dict, strict=True)

    def forward(self, **kwargs):
        """
        Dispatch to `forward_train` in training mode and to `forward_test`
        otherwise. All inputs are passed through as keyword arguments.
        """
        if self.training:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def forward_train(self, **kwargs):
        """Training forward pass; not implemented for this model."""
        raise NotImplementedError

    def forward_test(self,
                     query_feats=None,
                     video_feats=None,
                     start_ts=None,
                     end_ts=None,
                     scale_boundaries=None,
                     **kwargs):
        """
        Obtain matching scores and bbox bias of the top-k candidate anchors, with
        pre-extracted query features and video features as input.

        Args:
            query_feats: the pre-extracted text features.
            video_feats: the pre-extracted video features, a 3-D tensor whose
                last two dimensions are swapped before encoding.
            start_ts: the start timestamps of pre-defined multi-scale anchors.
            end_ts: the end timestamps of pre-defined multi-scale anchors.
            scale_boundaries: the begin and end anchor index for each scale in
                start_ts and end_ts; scale i spans
                [scale_boundaries[i], scale_boundaries[i + 1]). Only read when
                stage 2 is enabled.

        Returns:
            A tuple (final_scores, bbox_bias, starts, ends). When stage 2 is
            enabled, `starts` and `ends` only hold the timestamps of the
            anchors kept by the stage-1 top-k filter, concatenated over
            scales; otherwise they are `start_ts` and `end_ts` unchanged.
        """
        sent_feat = query_feats
        ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1))
        qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat)

        if self.enable_stage2:
            hit_indices = []
            starts = []
            ends = []
            filtered_ctx_feats = []
            for i in range(self.nscales):
                # Rank the anchors of this scale by stage-1 score.
                _, ranked = torch.sort(
                    qv_ctx_scores[i], dim=1, descending=True)
                # Union of every row's top-k anchor indices, de-duplicated
                # and sorted ascending. torch.unique does this on the
                # tensor's own device, avoiding a host round trip per scale.
                indices = torch.unique(ranked[:, :self.stage2_topk])
                indices = indices.to(video_feats.device)
                hit_indices.append(indices)

                filtered_ctx_feats.append(
                    torch.index_select(ctx_feats[i], 1, indices))

                # `indices` are relative to this scale, so slice the flat
                # timestamp tensors down to the scale before selecting.
                scale_first = scale_boundaries[i]
                scale_last = scale_boundaries[i + 1]
                filtered_start = torch.index_select(
                    start_ts[scale_first:scale_last], 0, indices)
                filtered_end = torch.index_select(
                    end_ts[scale_first:scale_last], 0, indices)
                starts.append(filtered_start)
                ends.append(filtered_end)

            starts = torch.cat(starts, dim=0)
            ends = torch.cat(ends, dim=0)

            qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2(
                video_feats, sent_feat, hit_indices, qv_ctx_scores)
            # The regressor only sees the anchors that survived filtering.
            ctx_feats = filtered_ctx_feats
        else:
            ctn_feats = None
            qv_merge_scores = qv_ctx_scores
            starts = start_ts
            ends = end_ts

        bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat)
        # Per-scale scores are joined along dim 1 and squashed to (0, 1).
        final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1))

        return final_scores, bbox_bias, starts, ends
|