model.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os
  3. import torch
  4. import torch.nn as nn
  5. from modelscope.metainfo import Models
  6. from modelscope.models.base.base_torch_model import TorchModel
  7. from modelscope.models.builder import MODELS
  8. from modelscope.utils.config import Config
  9. from modelscope.utils.constant import ModelFile, Tasks
  10. from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2,
  11. V2QRankerStage1, V2QRankerStage2)
  12. from .swin_transformer import SwinTransformerV2_1D
  13. @MODELS.register_module(
  14. Tasks.video_temporal_grounding, module_name=Models.soonet)
  15. class SOONet(TorchModel):
  16. """
  17. The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding
  18. in Long Videos'. The model is dynamically initialized with the following parts:
  19. - q2v_stage1: calculate qv_ctx_score.
  20. - v2q_stage1: calculate vq_ctx_score.
  21. - q2v_stage2: calculate qv_ctn_score.
  22. - v2q_stage2: calculate vq_ctn_score.
  23. - regressor: predict the offset of bounding box for each candidate anchor.
  24. """
  25. def __init__(self, model_dir: str, *args, **kwargs):
  26. """
  27. Initialize SOONet Model
  28. Args:
  29. model_dir: model id or path
  30. """
  31. super().__init__()
  32. config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
  33. self.config = Config.from_file(config_path).hyperparams
  34. nscales = self.config.nscales
  35. hidden_dim = self.config.hidden_dim
  36. snippet_length = self.config.snippet_length
  37. self.enable_stage2 = self.config.enable_stage2
  38. self.stage2_topk = self.config.stage2_topk
  39. self.nscales = nscales
  40. self.video_encoder = SwinTransformerV2_1D(
  41. patch_size=snippet_length,
  42. in_chans=hidden_dim,
  43. embed_dim=hidden_dim,
  44. depths=[2] * nscales,
  45. num_heads=[8] * nscales,
  46. window_size=[64] * nscales,
  47. mlp_ratio=2.,
  48. qkv_bias=True,
  49. drop_rate=0.,
  50. attn_drop_rate=0.,
  51. drop_path_rate=0.1,
  52. norm_layer=nn.LayerNorm,
  53. patch_norm=True,
  54. use_checkpoint=False,
  55. pretrained_window_sizes=[0] * nscales)
  56. self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim)
  57. self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim)
  58. if self.enable_stage2:
  59. self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim,
  60. snippet_length)
  61. self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim)
  62. self.regressor = BboxRegressor(hidden_dim, self.enable_stage2)
  63. # Load trained weights
  64. model_path = os.path.join(model_dir,
  65. 'SOONet_MAD_VIT-B-32_4Scale_10C.pth')
  66. state_dict = torch.load(model_path, map_location='cpu')['model']
  67. self.load_state_dict(state_dict, strict=True)
  68. def forward(self, **kwargs):
  69. if self.training:
  70. return self.forward_train(**kwargs)
  71. else:
  72. return self.forward_test(**kwargs)
  73. def forward_train(self, **kwargs):
  74. raise NotImplementedError
  75. def forward_test(self,
  76. query_feats=None,
  77. video_feats=None,
  78. start_ts=None,
  79. end_ts=None,
  80. scale_boundaries=None,
  81. **kwargs):
  82. """
  83. Obtain matching scores and bbox bias of the top-k candidate anchors, with
  84. pre-extracted query features and video features as input.
  85. Args:
  86. query_feats: the pre-extracted text features.
  87. video_feats: the pre-extracted video features.
  88. start_ts: the start timestamps of pre-defined multi-scale anchors.
  89. end_ts: the end timestamps of pre-defined multi-scale anchors.
  90. scale_boundaries: the begin and end anchor index for each scale in start_ts and end_ts.
  91. Returns:
  92. [final_scores, bbox_bias, starts, ends]
  93. """
  94. sent_feat = query_feats
  95. ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1))
  96. qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat)
  97. if self.enable_stage2:
  98. hit_indices = list()
  99. starts = list()
  100. ends = list()
  101. filtered_ctx_feats = list()
  102. for i in range(self.nscales):
  103. _, indices = torch.sort(
  104. qv_ctx_scores[i], dim=1, descending=True)
  105. indices, _ = torch.sort(
  106. torch.LongTensor(
  107. list(
  108. set(indices[:, :self.stage2_topk].flatten().cpu().
  109. numpy().tolist()))))
  110. indices = indices.to(video_feats.device)
  111. hit_indices.append(indices)
  112. filtered_ctx_feats.append(
  113. torch.index_select(ctx_feats[i], 1, indices))
  114. scale_first = scale_boundaries[i]
  115. scale_last = scale_boundaries[i + 1]
  116. filtered_start = torch.index_select(
  117. start_ts[scale_first:scale_last], 0, indices)
  118. filtered_end = torch.index_select(
  119. end_ts[scale_first:scale_last], 0, indices)
  120. starts.append(filtered_start)
  121. ends.append(filtered_end)
  122. starts = torch.cat(starts, dim=0)
  123. ends = torch.cat(ends, dim=0)
  124. qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2(
  125. video_feats, sent_feat, hit_indices, qv_ctx_scores)
  126. ctx_feats = filtered_ctx_feats
  127. else:
  128. ctn_feats = None
  129. qv_merge_scores = qv_ctx_scores
  130. starts = start_ts
  131. ends = end_ts
  132. bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat)
  133. final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1))
  134. return final_scores, bbox_bias, starts, ends