blocks.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class Q2VRankerStage1(nn.Module):
  7. """
  8. Used to calculate the qv_ctx_score with query embedding and multi anchor context embeddings as input.
  9. The qv_ctx_score is used to pre-rank and retain top-k related anchors.
  10. """
  11. def __init__(self, nscales, hidden_dim):
  12. super().__init__()
  13. self.fc = nn.Linear(hidden_dim, hidden_dim)
  14. self.nscales = nscales
  15. def forward(self, ctx_feats, qfeat):
  16. qfeat = self.fc(qfeat)
  17. qv_ctx_scores = list()
  18. for i in range(self.nscales):
  19. score = torch.einsum('bld,bd->bl',
  20. F.normalize(ctx_feats[i], p=2, dim=2),
  21. F.normalize(qfeat, p=2, dim=1))
  22. qv_ctx_scores.append(score)
  23. return qv_ctx_scores
  24. class V2QRankerStage1(nn.Module):
  25. """
  26. Used to calculate the vq_ctx_score with anchor context embeddings and multi query embeddings as input.
  27. """
  28. def __init__(self, nscales, hidden_dim):
  29. super().__init__()
  30. self.fc = nn.Linear(hidden_dim, hidden_dim)
  31. self.nscales = nscales
  32. def forward(self, ctx_feats, qfeat):
  33. vq_ctx_scores = list()
  34. for i in range(self.nscales):
  35. score = torch.einsum(
  36. 'bld,bd->bl', F.normalize(self.fc(ctx_feats[i]), p=2, dim=2),
  37. F.normalize(qfeat, p=2, dim=1))
  38. vq_ctx_scores.append(score)
  39. return vq_ctx_scores
  40. class Q2VRankerStage2(nn.Module):
  41. """
  42. Used to calculate the qv_ctn_score with query embedding and video sequence embedding as input.
  43. The qv_ctn_score is used to re-rank anchors.
  44. """
  45. def __init__(self, nscales, hidden_dim, snippet_length=10):
  46. super().__init__()
  47. self.nscales = nscales
  48. self.snippet_length = snippet_length
  49. self.qfc = nn.Linear(hidden_dim, hidden_dim)
  50. self.encoder = V2VAttention()
  51. def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores):
  52. qfeat = self.qfc(qfeat)
  53. qv_ctn_scores = list()
  54. qv_merge_scores = list()
  55. _, L, D = vfeats.size()
  56. ctn_feats = list()
  57. for i in range(self.nscales):
  58. anchor_length = self.snippet_length * 2**i
  59. assert L // anchor_length == qv_ctx_scores[i].size(1)
  60. qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1,
  61. hit_indices[i])
  62. ctn_feat = vfeats.view(L // anchor_length, anchor_length,
  63. D).detach()
  64. ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i])
  65. ctn_feat = self.encoder(
  66. ctn_feat,
  67. torch.ones(ctn_feat.size()[:2], device=ctn_feat.device))
  68. ctn_feats.append(ctn_feat)
  69. qv_ctn_score = torch.einsum(
  70. 'bkld,bd->bkl', F.normalize(ctn_feat.unsqueeze(0), p=2, dim=3),
  71. F.normalize(qfeat, p=2, dim=1))
  72. qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2)
  73. qv_ctn_scores.append(qv_ctn_score)
  74. qv_merge_scores.append(qv_ctx_score + qv_ctn_score)
  75. return qv_merge_scores, qv_ctn_scores, ctn_feats
  76. class V2QRankerStage2(nn.Module):
  77. """
  78. Used to calculate the vq_ctn_score with anchor content embeddings and multi query embeddings as input.
  79. """
  80. def __init__(self, nscales, hidden_dim):
  81. super().__init__()
  82. self.fc = nn.Linear(hidden_dim, hidden_dim)
  83. self.nscales = nscales
  84. def forward(self, ctn_feats, qfeat):
  85. vq_ctn_scores = list()
  86. for i in range(self.nscales):
  87. score = torch.einsum(
  88. 'bkld,bd->bkl',
  89. F.normalize(self.fc(ctn_feats[i]).unsqueeze(0), p=2, dim=3),
  90. F.normalize(qfeat, p=2, dim=1))
  91. score = torch.mean(score, dim=2)
  92. vq_ctn_scores.append(score)
  93. return vq_ctn_scores
  94. class V2VAttention(nn.Module):
  95. """
  96. Self-attention encoder for anchor frame sequence to encode intra-anchor knowledge.
  97. """
  98. def __init__(self):
  99. super().__init__()
  100. self.posemb = PositionEncoding(max_len=400, dim=512, dropout=0.0)
  101. self.encoder = MultiHeadAttention(dim=512, n_heads=8, dropout=0.1)
  102. self.dropout = nn.Dropout(0.0)
  103. def forward(self, video_feats, video_masks):
  104. mask = torch.einsum('bm,bn->bmn', video_masks,
  105. video_masks).unsqueeze(1)
  106. residual = video_feats
  107. video_feats = video_feats + self.posemb(video_feats)
  108. out = self.encoder(
  109. query=video_feats, key=video_feats, value=video_feats, mask=mask)
  110. video_feats = self.dropout(residual
  111. + out) * video_masks.unsqueeze(2).float()
  112. return video_feats
  113. class BboxRegressor(nn.Module):
  114. """
  115. Predict the offset of bounding box for each candidate anchor.
  116. """
  117. def __init__(self, hidden_dim, enable_stage2=False):
  118. super().__init__()
  119. self.fc_ctx = nn.Linear(hidden_dim, hidden_dim)
  120. self.fc_q = nn.Linear(hidden_dim, hidden_dim)
  121. if enable_stage2:
  122. self.fc_ctn = nn.Linear(hidden_dim, hidden_dim)
  123. self.attn = SelfAttention(hidden_dim)
  124. self.predictor = nn.Sequential(
  125. nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
  126. nn.Linear(hidden_dim, 2))
  127. else:
  128. self.predictor = nn.Sequential(
  129. nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
  130. nn.Linear(hidden_dim, 2))
  131. self.enable_stage2 = enable_stage2
  132. def forward(self, ctx_feats, ctn_feats, qfeat):
  133. qfeat = self.fc_q(qfeat)
  134. ctx_feats = torch.cat(ctx_feats, dim=1)
  135. ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu(
  136. qfeat.unsqueeze(1))
  137. if self.enable_stage2 and ctn_feats:
  138. ctn_fuse_feats = list()
  139. for i in range(len(ctn_feats)):
  140. out = F.relu(self.fc_ctn(ctn_feats[i]).unsqueeze(0)) * F.relu(
  141. qfeat.unsqueeze(1).unsqueeze(1))
  142. out = self.attn(out)
  143. ctn_fuse_feats.append(out)
  144. ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1)
  145. fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1)
  146. else:
  147. fuse_feats = ctx_fuse_feats
  148. out = self.predictor(fuse_feats)
  149. return out
  150. class SelfAttention(nn.Module):
  151. """
  152. Obtain pooled features by self-attentive pooling.
  153. """
  154. def __init__(self, hidden_dim):
  155. super().__init__()
  156. self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
  157. self.relu = nn.ReLU()
  158. self.fc2 = nn.Linear(hidden_dim // 2, 1)
  159. def forward(self, x):
  160. att = self.fc2(self.relu(self.fc1(x))).squeeze(3)
  161. att = F.softmax(att, dim=2).unsqueeze(3)
  162. out = torch.sum(x * att, dim=2)
  163. return out
  164. class PositionEncoding(nn.Module):
  165. """
  166. An implementation of trainable positional embedding which is added to
  167. sequence features to inject time/position information.
  168. Args:
  169. max_len: The max number of trainable positional embeddings.
  170. dim: the dimension of positional embedding.
  171. """
  172. def __init__(self, max_len, dim, dropout=0.0):
  173. super(PositionEncoding, self).__init__()
  174. self.embed = nn.Embedding(max_len, dim)
  175. self.relu = nn.ReLU()
  176. self.dropout = nn.Dropout(dropout)
  177. def forward(self, x):
  178. batch_size, seq_len = x.shape[:2]
  179. pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
  180. pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1)
  181. pos_emb = self.dropout(self.relu(self.embed(pos_ids)))
  182. return pos_emb
  183. class MultiHeadAttention(nn.Module):
  184. """
  185. An implementation of multi-head attention module, as described in
  186. 'Attention Is All You Need <https://arxiv.org/abs/1706.03762>'
  187. Args:
  188. dim: the dimension of features of hidden layers.
  189. n_heads: the number of head.
  190. """
  191. def __init__(self, dim, n_heads, dropout=0.0):
  192. super(MultiHeadAttention, self).__init__()
  193. self.dim = dim
  194. self.n_heads = n_heads
  195. self.head_dim = dim // n_heads
  196. self.to_q = nn.Linear(dim, dim)
  197. self.to_k = nn.Linear(dim, dim)
  198. self.to_v = nn.Linear(dim, dim)
  199. self.dropout = nn.Dropout(dropout)
  200. self.softmax = nn.Softmax(dim=-1)
  201. def transpose_for_scores(self, x):
  202. new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim)
  203. x = x.view(*new_x_shape)
  204. return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
  205. def forward(self, query, key, value, mask):
  206. q = self.to_q(query)
  207. k = self.to_k(key)
  208. v = self.to_v(value)
  209. q_trans = self.transpose_for_scores(q)
  210. k_trans = self.transpose_for_scores(k)
  211. v_trans = self.transpose_for_scores(v)
  212. att = torch.matmul(q_trans, k_trans.transpose(-1,
  213. -2)) # (N, nh, Lq, L)
  214. att = att / math.sqrt(self.head_dim)
  215. att = mask_logits(att, mask)
  216. att = self.softmax(att)
  217. att = self.dropout(att)
  218. ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh)
  219. ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh)
  220. shape = ctx_v.size()[:-2] + (self.dim, )
  221. ctx_v = ctx_v.view(*shape) # (N, Lq, D)
  222. return ctx_v
  223. def mask_logits(inputs, mask, mask_value=-1e30):
  224. mask = mask.type(torch.float32)
  225. return inputs + (1.0 - mask) * mask_value