# The implementation is adopted from https://github.com/TengdaHan/CoCLR, # made publicly available under the Apache License, Version 2.0 at https://github.com/TengdaHan/CoCLR # Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved. import torch import torch.nn as nn class InceptionBaseConv3D(nn.Module): """ Constructs basic inception 3D conv. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. """ def __init__(self, cfg, in_planes, out_planes, kernel_size, stride, padding=0): super(InceptionBaseConv3D, self).__init__() self.conv = nn.Conv3d( in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.bn = nn.BatchNorm3d(out_planes) self.relu = nn.ReLU(inplace=True) # init self.conv.weight.data.normal_( mean=0, std=0.01) # original s3d is truncated normal within 2 std self.bn.weight.data.fill_(1) self.bn.bias.data.zero_() def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x class InceptionBlock3D(nn.Module): """ Element constructing the S3D/S3DG. See models/base/backbone.py L99-186. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. """ def __init__(self, cfg, in_planes, out_planes): super(InceptionBlock3D, self).__init__() _gating = cfg.VIDEO.BACKBONE.BRANCH.GATING assert len(out_planes) == 6 assert isinstance(out_planes, list) [ num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a, num_out_2_0b, num_out_3_0b ] = out_planes self.branch0 = nn.Sequential( InceptionBaseConv3D( cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), ) self.branch1 = nn.Sequential( InceptionBaseConv3D( cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1), STConv3d( cfg, num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1), ) self.branch2 = nn.Sequential( InceptionBaseConv3D( cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1), STConv3d( cfg, num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1), ) self.branch3 = nn.Sequential( nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), InceptionBaseConv3D( cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1), ) self.out_channels = sum( [num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) self.gating = _gating if _gating: self.gating_b0 = SelfGating(num_out_0_0a) self.gating_b1 = SelfGating(num_out_1_0b) self.gating_b2 = SelfGating(num_out_2_0b) self.gating_b3 = SelfGating(num_out_3_0b) def forward(self, x): x0 = self.branch0(x) x1 = self.branch1(x) x2 = self.branch2(x) x3 = self.branch3(x) if self.gating: x0 = self.gating_b0(x0) x1 = self.gating_b1(x1) x2 = self.gating_b2(x2) x3 = self.gating_b3(x3) out = torch.cat((x0, x1, x2, x3), 1) return out class SelfGating(nn.Module): def __init__(self, input_dim): super(SelfGating, self).__init__() self.fc = nn.Linear(input_dim, input_dim) def forward(self, input_tensor): """Feature gating as used in S3D-G""" spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) weights = self.fc(spatiotemporal_average) weights = torch.sigmoid(weights) return weights[:, :, None, None, None] * input_tensor class STConv3d(nn.Module): """ Element constructing the S3D/S3DG. See models/base/backbone.py L99-186. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. """ def __init__(self, cfg, in_planes, out_planes, kernel_size, stride, padding=0): super(STConv3d, self).__init__() if isinstance(stride, tuple): t_stride = stride[0] stride = stride[-1] else: # int t_stride = stride self.bn_mmt = cfg.BN.MOMENTUM self.bn_eps = float(cfg.BN.EPS) self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride, t_stride, padding) def _construct_branch(self, cfg, in_planes, out_planes, kernel_size, stride, t_stride, padding=0): self.conv1 = nn.Conv3d( in_planes, out_planes, kernel_size=(1, kernel_size, kernel_size), stride=(1, stride, stride), padding=(0, padding, padding), bias=False) self.conv2 = nn.Conv3d( out_planes, out_planes, kernel_size=(kernel_size, 1, 1), stride=(t_stride, 1, 1), padding=(padding, 0, 0), bias=False) self.bn1 = nn.BatchNorm3d( out_planes, eps=self.bn_eps, momentum=self.bn_mmt) self.bn2 = nn.BatchNorm3d( out_planes, eps=self.bn_eps, momentum=self.bn_mmt) self.relu = nn.ReLU(inplace=True) # init self.conv1.weight.data.normal_( mean=0, std=0.01) # original s3d is truncated normal within 2 std self.conv2.weight.data.normal_( mean=0, std=0.01) # original s3d is truncated normal within 2 std self.bn1.weight.data.fill_(1) self.bn1.bias.data.zero_() self.bn2.weight.data.fill_(1) self.bn2.bias.data.zero_() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) return x class Inception3D(nn.Module): """ Backbone architecture for I3D/S3DG. Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. """ def __init__(self, cfg): """ Args: cfg (Config): global config object. """ super(Inception3D, self).__init__() _input_channel = cfg.DATA.NUM_INPUT_CHANNELS self._construct_backbone(cfg, _input_channel) def _construct_backbone(self, cfg, input_channel): # ------------------- Block 1 ------------------- self.Conv_1a = STConv3d( cfg, input_channel, 64, kernel_size=7, stride=2, padding=3) self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) # ------------------- Block 2 ------------------- self.MaxPool_2a = nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) self.Conv_2b = InceptionBaseConv3D( cfg, 64, 64, kernel_size=1, stride=1) self.Conv_2c = STConv3d( cfg, 64, 192, kernel_size=3, stride=1, padding=1) self.block2 = nn.Sequential( self.MaxPool_2a, # (64, 32, 56, 56) self.Conv_2b, # (64, 32, 56, 56) self.Conv_2c) # (192, 32, 56, 56) # ------------------- Block 3 ------------------- self.MaxPool_3a = nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) self.Mixed_3b = InceptionBlock3D( cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32]) self.Mixed_3c = InceptionBlock3D( cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64]) self.block3 = nn.Sequential( self.MaxPool_3a, # (192, 32, 28, 28) self.Mixed_3b, # (256, 32, 28, 28) self.Mixed_3c) # (480, 32, 28, 28) # ------------------- Block 4 ------------------- self.MaxPool_4a = nn.MaxPool3d( kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) self.Mixed_4b = InceptionBlock3D( cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64]) self.Mixed_4c = InceptionBlock3D( cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64]) self.Mixed_4d = InceptionBlock3D( cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64]) self.Mixed_4e = InceptionBlock3D( cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64]) self.Mixed_4f = InceptionBlock3D( cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128]) self.block4 = nn.Sequential( self.MaxPool_4a, # (480, 16, 14, 14) self.Mixed_4b, # (512, 16, 14, 14) self.Mixed_4c, # (512, 16, 14, 14) self.Mixed_4d, # (512, 16, 14, 14) self.Mixed_4e, # (528, 16, 14, 14) self.Mixed_4f) # (832, 16, 14, 14) # ------------------- Block 5 ------------------- self.MaxPool_5a = nn.MaxPool3d( kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)) self.Mixed_5b = InceptionBlock3D( cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128]) self.Mixed_5c = InceptionBlock3D( cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128]) self.block5 = nn.Sequential( self.MaxPool_5a, # (832, 8, 7, 7) self.Mixed_5b, # (832, 8, 7, 7) self.Mixed_5c) # (1024, 8, 7, 7) def forward(self, x): if isinstance(x, dict): x = x['video'] x = self.block1(x) x = self.block2(x) x = self.block3(x) x = self.block4(x) x = self.block5(x) return x