| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- # The implementation is adapted from https://github.com/TengdaHan/CoCLR,
- # made publicly available under the Apache License, Version 2.0 at https://github.com/TengdaHan/CoCLR
- # Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved.
- import torch
- import torch.nn as nn
class InceptionBaseConv3D(nn.Module):
    """
    Basic inception 3D conv unit: bias-free Conv3d -> BatchNorm3d -> ReLU.
    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.

    Args:
        cfg: global config object. Accepted for signature parity with the
            other blocks; this class does not read it.
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        kernel_size: forwarded unchanged to ``nn.Conv3d``.
        stride: forwarded unchanged to ``nn.Conv3d``.
        padding: forwarded unchanged to ``nn.Conv3d`` (default 0).
    """

    def __init__(self,
                 cfg,
                 in_planes,
                 out_planes,
                 kernel_size,
                 stride,
                 padding=0):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,  # the following BatchNorm supplies the shift term
        )
        self.conv = nn.Conv3d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Weight init: plain N(0, 0.01) for the conv
        # (original s3d is truncated normal within 2 std),
        # identity affine transform for the batch norm.
        nn.init.normal_(self.conv.weight, mean=0, std=0.01)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.zeros_(self.bn.bias)

    def forward(self, x):
        """Apply conv, batch norm and ReLU in sequence and return the result."""
        return self.relu(self.bn(self.conv(x)))
class InceptionBlock3D(nn.Module):
    """
    Four-branch inception element used to construct the S3D/S3DG.
    See models/base/backbone.py L99-186.
    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.

    Branches (all fed the same input, outputs concatenated on dim 1):
        0: 1x1x1 conv
        1: 1x1x1 conv -> STConv3d (kernel 3)
        2: 1x1x1 conv -> STConv3d (kernel 3)
        3: 3x3x3 max pool (stride 1) -> 1x1x1 conv

    Args:
        cfg: global config object; ``cfg.VIDEO.BACKBONE.BRANCH.GATING``
            toggles per-branch ``SelfGating``. Also forwarded to sub-blocks.
        in_planes (int): number of input channels.
        out_planes (list): exactly six channel counts, in the order
            [branch0 out, branch1 mid, branch1 out,
             branch2 mid, branch2 out, branch3 out].
    """

    def __init__(self, cfg, in_planes, out_planes):
        super().__init__()
        use_gating = cfg.VIDEO.BACKBONE.BRANCH.GATING

        assert len(out_planes) == 6
        assert isinstance(out_planes, list)
        b0_out, b1_mid, b1_out, b2_mid, b2_out, b3_out = out_planes

        def pointwise(width):
            # 1x1x1 projection from the block input to `width` channels.
            return InceptionBaseConv3D(
                cfg, in_planes, width, kernel_size=1, stride=1)

        def separable(mid, out):
            # Kernel-3, stride-1, padding-1 STConv3d from `mid` to `out` channels.
            return STConv3d(
                cfg, mid, out, kernel_size=3, stride=1, padding=1)

        # NOTE: branches are built in this exact order so that submodule
        # registration (and therefore parameter ordering) stays stable.
        self.branch0 = nn.Sequential(pointwise(b0_out))
        self.branch1 = nn.Sequential(
            pointwise(b1_mid),
            separable(b1_mid, b1_out),
        )
        self.branch2 = nn.Sequential(
            pointwise(b2_mid),
            separable(b2_mid, b2_out),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            pointwise(b3_out),
        )

        # Channel count of the concatenated output.
        self.out_channels = sum((b0_out, b1_out, b2_out, b3_out))

        self.gating = use_gating
        if use_gating:
            self.gating_b0 = SelfGating(b0_out)
            self.gating_b1 = SelfGating(b1_out)
            self.gating_b2 = SelfGating(b2_out)
            self.gating_b3 = SelfGating(b3_out)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1."""
        branch_feats = [
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        ]
        if self.gating:
            gates = (self.gating_b0, self.gating_b1, self.gating_b2,
                     self.gating_b3)
            branch_feats = [
                gate(feat) for gate, feat in zip(gates, branch_feats)
            ]
        return torch.cat(branch_feats, 1)
class SelfGating(nn.Module):
    """
    Feature gating as used in S3D-G.

    Each channel of the input is rescaled by a sigmoid weight computed from
    the input's mean over dims 2, 3 and 4, passed through one linear layer.

    Args:
        input_dim (int): size of dim 1 of the input; the linear layer maps
            ``input_dim -> input_dim``.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, input_tensor):
        """Return ``input_tensor`` scaled per (dim 0, dim 1) entry by its gate."""
        # Collapse dims 2, 3, 4 into a single descriptor per (dim 0, dim 1).
        pooled = input_tensor.mean(dim=(2, 3, 4))
        gate = torch.sigmoid(self.fc(pooled))
        # Re-expand the gate to five dims so it broadcasts over dims 2, 3, 4.
        return gate[..., None, None, None] * input_tensor
class STConv3d(nn.Module):
    """
    Separable spatio-temporal conv element constructing the S3D/S3DG.
    See models/base/backbone.py L99-186.
    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.

    The block applies a (1, k, k) convolution followed by a (k, 1, 1)
    convolution, each followed by BatchNorm3d and ReLU.

    Args:
        cfg: global config object; ``cfg.BN.MOMENTUM`` and ``cfg.BN.EPS`` are
            read for the batch-norm layers (``EPS`` is passed through
            ``float()``). Not used otherwise.
        in_planes (int): number of input channels.
        out_planes (int): number of output channels (used by both convs).
        kernel_size (int): ``k`` in the kernel shapes above.
        stride (int | tuple | list): an int is used for both the (1, k, k)
            and (k, 1, 1) convs. For a sequence, the first element is the
            stride of the (k, 1, 1) conv and the last element is the stride
            of the (1, k, k) conv.
        padding (int): padding applied on the two k-sized dims of the first
            conv and on the k-sized dim of the second conv.
    """

    def __init__(self,
                 cfg,
                 in_planes,
                 out_planes,
                 kernel_size,
                 stride,
                 padding=0):
        super(STConv3d, self).__init__()
        # Accept lists as well as tuples: values loaded from YAML/JSON configs
        # arrive as lists, and treating a list like an int would build a
        # malformed nested stride for the Conv3d layers below.
        if isinstance(stride, (tuple, list)):
            t_stride = stride[0]
            stride = stride[-1]
        else:  # int
            t_stride = stride

        self.bn_mmt = cfg.BN.MOMENTUM
        self.bn_eps = float(cfg.BN.EPS)
        self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride,
                               t_stride, padding)

    def _construct_branch(self,
                          cfg,
                          in_planes,
                          out_planes,
                          kernel_size,
                          stride,
                          t_stride,
                          padding=0):
        """Create the two convs, their batch norms and the shared ReLU.

        ``stride`` is the stride of the (1, k, k) conv and ``t_stride`` the
        stride of the (k, 1, 1) conv. ``cfg`` is not read here.
        """
        # (1, k, k) convolution.
        self.conv1 = nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=(1, kernel_size, kernel_size),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False)
        # (k, 1, 1) convolution.
        self.conv2 = nn.Conv3d(
            out_planes,
            out_planes,
            kernel_size=(kernel_size, 1, 1),
            stride=(t_stride, 1, 1),
            padding=(padding, 0, 0),
            bias=False)

        self.bn1 = nn.BatchNorm3d(
            out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
        self.bn2 = nn.BatchNorm3d(
            out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
        # One ReLU module is reused after both batch norms.
        self.relu = nn.ReLU(inplace=True)

        # init
        self.conv1.weight.data.normal_(
            mean=0, std=0.01)  # original s3d is truncated normal within 2 std
        self.conv2.weight.data.normal_(
            mean=0, std=0.01)  # original s3d is truncated normal within 2 std

        self.bn1.weight.data.fill_(1)
        self.bn1.bias.data.zero_()
        self.bn2.weight.data.fill_(1)
        self.bn2.bias.data.zero_()

    def forward(self, x):
        """Apply conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> ReLU."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x
class Inception3D(nn.Module):
    """
    Backbone architecture for I3D/S3DG.
    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.

    The network is five sequential stages (``block1`` .. ``block5``). Every
    layer is also exposed as a named attribute (``Conv_1a``, ``Mixed_3b``,
    ...), so each layer is reachable under two names.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (Config): global config object. ``cfg.DATA.NUM_INPUT_CHANNELS``
                gives the channel count of the input; the rest is forwarded
                to the sub-blocks.
        """
        super().__init__()
        self._construct_backbone(cfg, cfg.DATA.NUM_INPUT_CHANNELS)

    def _construct_backbone(self, cfg, input_channel):
        """Build all layers and group them into ``block1`` .. ``block5``.

        The trailing shape comments are the original author's annotations;
        presumably (C, T, H, W) for a 64x224x224 clip — TODO confirm.
        """

        def spatial_pool():
            # Max pool over the last two dims only (kernel 3, stride 2).
            return nn.MaxPool3d(
                kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        def mixed(in_planes, *widths):
            # Inception block; `widths` are its six channel counts.
            return InceptionBlock3D(
                cfg, in_planes=in_planes, out_planes=list(widths))

        # NOTE: attributes are assigned in the same order as the layers run,
        # keeping submodule registration order stable.

        # ------------------- Block 1 -------------------
        self.Conv_1a = STConv3d(
            cfg, input_channel, 64, kernel_size=7, stride=2, padding=3)
        self.block1 = nn.Sequential(self.Conv_1a)  # (64, 32, 112, 112)

        # ------------------- Block 2 -------------------
        self.MaxPool_2a = spatial_pool()
        self.Conv_2b = InceptionBaseConv3D(
            cfg, 64, 64, kernel_size=1, stride=1)
        self.Conv_2c = STConv3d(
            cfg, 64, 192, kernel_size=3, stride=1, padding=1)
        self.block2 = nn.Sequential(
            self.MaxPool_2a,  # (64, 32, 56, 56)
            self.Conv_2b,  # (64, 32, 56, 56)
            self.Conv_2c,  # (192, 32, 56, 56)
        )

        # ------------------- Block 3 -------------------
        self.MaxPool_3a = spatial_pool()
        self.Mixed_3b = mixed(192, 64, 96, 128, 16, 32, 32)
        self.Mixed_3c = mixed(256, 128, 128, 192, 32, 96, 64)
        self.block3 = nn.Sequential(
            self.MaxPool_3a,  # (192, 32, 28, 28)
            self.Mixed_3b,  # (256, 32, 28, 28)
            self.Mixed_3c,  # (480, 32, 28, 28)
        )

        # ------------------- Block 4 -------------------
        self.MaxPool_4a = nn.MaxPool3d(
            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
        self.Mixed_4b = mixed(480, 192, 96, 208, 16, 48, 64)
        self.Mixed_4c = mixed(512, 160, 112, 224, 24, 64, 64)
        self.Mixed_4d = mixed(512, 128, 128, 256, 24, 64, 64)
        self.Mixed_4e = mixed(512, 112, 144, 288, 32, 64, 64)
        self.Mixed_4f = mixed(528, 256, 160, 320, 32, 128, 128)
        self.block4 = nn.Sequential(
            self.MaxPool_4a,  # (480, 16, 14, 14)
            self.Mixed_4b,  # (512, 16, 14, 14)
            self.Mixed_4c,  # (512, 16, 14, 14)
            self.Mixed_4d,  # (512, 16, 14, 14)
            self.Mixed_4e,  # (528, 16, 14, 14)
            self.Mixed_4f,  # (832, 16, 14, 14)
        )

        # ------------------- Block 5 -------------------
        self.MaxPool_5a = nn.MaxPool3d(
            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
        self.Mixed_5b = mixed(832, 256, 160, 320, 32, 128, 128)
        self.Mixed_5c = mixed(832, 384, 192, 384, 48, 128, 128)
        self.block5 = nn.Sequential(
            self.MaxPool_5a,  # (832, 8, 7, 7)
            self.Mixed_5b,  # (832, 8, 7, 7)
            self.Mixed_5c,  # (1024, 8, 7, 7)
        )

    def forward(self, x):
        """Run the five stages in order.

        Args:
            x: the input tensor, or a dict whose ``'video'`` entry holds it.

        Returns:
            The output of ``block5``.
        """
        feats = x['video'] if isinstance(x, dict) else x
        stages = (self.block1, self.block2, self.block3, self.block4,
                  self.block5)
        for stage in stages:
            feats = stage(feats)
        return feats
|