s3dg.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # The implementation is adopted from https://github.com/TengdaHan/CoCLR,
  2. # made publicly available under the Apache License, Version 2.0 at https://github.com/TengdaHan/CoCLR
  3. # Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved.
  4. import torch
  5. import torch.nn as nn
  6. class InceptionBaseConv3D(nn.Module):
  7. """
  8. Constructs basic inception 3D conv.
  9. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
  10. """
  11. def __init__(self,
  12. cfg,
  13. in_planes,
  14. out_planes,
  15. kernel_size,
  16. stride,
  17. padding=0):
  18. super(InceptionBaseConv3D, self).__init__()
  19. self.conv = nn.Conv3d(
  20. in_planes,
  21. out_planes,
  22. kernel_size=kernel_size,
  23. stride=stride,
  24. padding=padding,
  25. bias=False)
  26. self.bn = nn.BatchNorm3d(out_planes)
  27. self.relu = nn.ReLU(inplace=True)
  28. # init
  29. self.conv.weight.data.normal_(
  30. mean=0, std=0.01) # original s3d is truncated normal within 2 std
  31. self.bn.weight.data.fill_(1)
  32. self.bn.bias.data.zero_()
  33. def forward(self, x):
  34. x = self.conv(x)
  35. x = self.bn(x)
  36. x = self.relu(x)
  37. return x
  38. class InceptionBlock3D(nn.Module):
  39. """
  40. Element constructing the S3D/S3DG.
  41. See models/base/backbone.py L99-186.
  42. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
  43. """
  44. def __init__(self, cfg, in_planes, out_planes):
  45. super(InceptionBlock3D, self).__init__()
  46. _gating = cfg.VIDEO.BACKBONE.BRANCH.GATING
  47. assert len(out_planes) == 6
  48. assert isinstance(out_planes, list)
  49. [
  50. num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a,
  51. num_out_2_0b, num_out_3_0b
  52. ] = out_planes
  53. self.branch0 = nn.Sequential(
  54. InceptionBaseConv3D(
  55. cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), )
  56. self.branch1 = nn.Sequential(
  57. InceptionBaseConv3D(
  58. cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1),
  59. STConv3d(
  60. cfg,
  61. num_out_1_0a,
  62. num_out_1_0b,
  63. kernel_size=3,
  64. stride=1,
  65. padding=1),
  66. )
  67. self.branch2 = nn.Sequential(
  68. InceptionBaseConv3D(
  69. cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1),
  70. STConv3d(
  71. cfg,
  72. num_out_2_0a,
  73. num_out_2_0b,
  74. kernel_size=3,
  75. stride=1,
  76. padding=1),
  77. )
  78. self.branch3 = nn.Sequential(
  79. nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
  80. InceptionBaseConv3D(
  81. cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1),
  82. )
  83. self.out_channels = sum(
  84. [num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b])
  85. self.gating = _gating
  86. if _gating:
  87. self.gating_b0 = SelfGating(num_out_0_0a)
  88. self.gating_b1 = SelfGating(num_out_1_0b)
  89. self.gating_b2 = SelfGating(num_out_2_0b)
  90. self.gating_b3 = SelfGating(num_out_3_0b)
  91. def forward(self, x):
  92. x0 = self.branch0(x)
  93. x1 = self.branch1(x)
  94. x2 = self.branch2(x)
  95. x3 = self.branch3(x)
  96. if self.gating:
  97. x0 = self.gating_b0(x0)
  98. x1 = self.gating_b1(x1)
  99. x2 = self.gating_b2(x2)
  100. x3 = self.gating_b3(x3)
  101. out = torch.cat((x0, x1, x2, x3), 1)
  102. return out
  103. class SelfGating(nn.Module):
  104. def __init__(self, input_dim):
  105. super(SelfGating, self).__init__()
  106. self.fc = nn.Linear(input_dim, input_dim)
  107. def forward(self, input_tensor):
  108. """Feature gating as used in S3D-G"""
  109. spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4])
  110. weights = self.fc(spatiotemporal_average)
  111. weights = torch.sigmoid(weights)
  112. return weights[:, :, None, None, None] * input_tensor
  113. class STConv3d(nn.Module):
  114. """
  115. Element constructing the S3D/S3DG.
  116. See models/base/backbone.py L99-186.
  117. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
  118. """
  119. def __init__(self,
  120. cfg,
  121. in_planes,
  122. out_planes,
  123. kernel_size,
  124. stride,
  125. padding=0):
  126. super(STConv3d, self).__init__()
  127. if isinstance(stride, tuple):
  128. t_stride = stride[0]
  129. stride = stride[-1]
  130. else: # int
  131. t_stride = stride
  132. self.bn_mmt = cfg.BN.MOMENTUM
  133. self.bn_eps = float(cfg.BN.EPS)
  134. self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride,
  135. t_stride, padding)
  136. def _construct_branch(self,
  137. cfg,
  138. in_planes,
  139. out_planes,
  140. kernel_size,
  141. stride,
  142. t_stride,
  143. padding=0):
  144. self.conv1 = nn.Conv3d(
  145. in_planes,
  146. out_planes,
  147. kernel_size=(1, kernel_size, kernel_size),
  148. stride=(1, stride, stride),
  149. padding=(0, padding, padding),
  150. bias=False)
  151. self.conv2 = nn.Conv3d(
  152. out_planes,
  153. out_planes,
  154. kernel_size=(kernel_size, 1, 1),
  155. stride=(t_stride, 1, 1),
  156. padding=(padding, 0, 0),
  157. bias=False)
  158. self.bn1 = nn.BatchNorm3d(
  159. out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
  160. self.bn2 = nn.BatchNorm3d(
  161. out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
  162. self.relu = nn.ReLU(inplace=True)
  163. # init
  164. self.conv1.weight.data.normal_(
  165. mean=0, std=0.01) # original s3d is truncated normal within 2 std
  166. self.conv2.weight.data.normal_(
  167. mean=0, std=0.01) # original s3d is truncated normal within 2 std
  168. self.bn1.weight.data.fill_(1)
  169. self.bn1.bias.data.zero_()
  170. self.bn2.weight.data.fill_(1)
  171. self.bn2.bias.data.zero_()
  172. def forward(self, x):
  173. x = self.conv1(x)
  174. x = self.bn1(x)
  175. x = self.relu(x)
  176. x = self.conv2(x)
  177. x = self.bn2(x)
  178. x = self.relu(x)
  179. return x
  180. class Inception3D(nn.Module):
  181. """
  182. Backbone architecture for I3D/S3DG.
  183. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
  184. """
  185. def __init__(self, cfg):
  186. """
  187. Args:
  188. cfg (Config): global config object.
  189. """
  190. super(Inception3D, self).__init__()
  191. _input_channel = cfg.DATA.NUM_INPUT_CHANNELS
  192. self._construct_backbone(cfg, _input_channel)
  193. def _construct_backbone(self, cfg, input_channel):
  194. # ------------------- Block 1 -------------------
  195. self.Conv_1a = STConv3d(
  196. cfg, input_channel, 64, kernel_size=7, stride=2, padding=3)
  197. self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112)
  198. # ------------------- Block 2 -------------------
  199. self.MaxPool_2a = nn.MaxPool3d(
  200. kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
  201. self.Conv_2b = InceptionBaseConv3D(
  202. cfg, 64, 64, kernel_size=1, stride=1)
  203. self.Conv_2c = STConv3d(
  204. cfg, 64, 192, kernel_size=3, stride=1, padding=1)
  205. self.block2 = nn.Sequential(
  206. self.MaxPool_2a, # (64, 32, 56, 56)
  207. self.Conv_2b, # (64, 32, 56, 56)
  208. self.Conv_2c) # (192, 32, 56, 56)
  209. # ------------------- Block 3 -------------------
  210. self.MaxPool_3a = nn.MaxPool3d(
  211. kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
  212. self.Mixed_3b = InceptionBlock3D(
  213. cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32])
  214. self.Mixed_3c = InceptionBlock3D(
  215. cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64])
  216. self.block3 = nn.Sequential(
  217. self.MaxPool_3a, # (192, 32, 28, 28)
  218. self.Mixed_3b, # (256, 32, 28, 28)
  219. self.Mixed_3c) # (480, 32, 28, 28)
  220. # ------------------- Block 4 -------------------
  221. self.MaxPool_4a = nn.MaxPool3d(
  222. kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  223. self.Mixed_4b = InceptionBlock3D(
  224. cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64])
  225. self.Mixed_4c = InceptionBlock3D(
  226. cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64])
  227. self.Mixed_4d = InceptionBlock3D(
  228. cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64])
  229. self.Mixed_4e = InceptionBlock3D(
  230. cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64])
  231. self.Mixed_4f = InceptionBlock3D(
  232. cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128])
  233. self.block4 = nn.Sequential(
  234. self.MaxPool_4a, # (480, 16, 14, 14)
  235. self.Mixed_4b, # (512, 16, 14, 14)
  236. self.Mixed_4c, # (512, 16, 14, 14)
  237. self.Mixed_4d, # (512, 16, 14, 14)
  238. self.Mixed_4e, # (528, 16, 14, 14)
  239. self.Mixed_4f) # (832, 16, 14, 14)
  240. # ------------------- Block 5 -------------------
  241. self.MaxPool_5a = nn.MaxPool3d(
  242. kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
  243. self.Mixed_5b = InceptionBlock3D(
  244. cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128])
  245. self.Mixed_5c = InceptionBlock3D(
  246. cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128])
  247. self.block5 = nn.Sequential(
  248. self.MaxPool_5a, # (832, 8, 7, 7)
  249. self.Mixed_5b, # (832, 8, 7, 7)
  250. self.Mixed_5c) # (1024, 8, 7, 7)
  251. def forward(self, x):
  252. if isinstance(x, dict):
  253. x = x['video']
  254. x = self.block1(x)
  255. x = self.block2(x)
  256. x = self.block3(x)
  257. x = self.block4(x)
  258. x = self.block5(x)
  259. return x