| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- # The implementation here is modified based on timm,
- # originally Apache 2.0 License and publicly available at
- # https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/fpn_fusion.py
- import torch.nn as nn
class FPNFusionModule(nn.Module):
    """FPN-style cross-scale feature fusion module.

    Each input feature map is first projected to a common channel dim
    (1x1 conv followed by group norm), then the maps are fused one after
    another, starting from the last index and walking down to index 0.
    """

    def __init__(self, embed_dims, fuse_dim=256, n_block=4, use_bn=False):
        """Initializes the model.

        Args:
            embed_dims: the list of channel dim for different scale feature
                maps (i.e., the input)
            fuse_dim: the channel dim of the fused feature map (i.e., the
                output)
            n_block: the number of multi-scale features (default=4)
            use_bn: whether to use bn
        """
        super().__init__()

        self.embed_dims = embed_dims
        self.fuse_dim = fuse_dim
        self.n_block = n_block

        # cross-scale fusion layers
        self.multi_scaler = _make_multi_scale_layers(
            embed_dims, fuse_dim, use_bn=use_bn, n_block=n_block)

    def forward(self, x_blocks):
        """Fuse the multi-scale features.

        Args:
            x_blocks: sequence of feature maps, one per scale, indexed
                0 .. n_block - 1. The sequence itself is not modified.
                (presumably index n_block - 1 is the coarsest scale, since
                the running result is upsampled to the spatial size of each
                lower index - confirm against the backbone)

        Returns:
            list of fused feature maps in processing order, i.e. the entry
            for index n_block - 1 comes first and the entry for index 0 last.
        """
        scaler = self.multi_scaler

        # Work on a shallow copy so that the caller's container keeps its
        # original tensors; this also lets tuples be passed in.
        feats = list(x_blocks)

        # preparation: channel reduction and normalization
        for idx in range(self.n_block - 1, -1, -1):
            reduce = getattr(scaler, f'layer_{idx}_rn')
            norm = getattr(scaler, f'p_norm_{idx}')
            feats[idx] = norm(reduce(feats[idx]))

        # cross-scale fusion
        refined_embeds = []
        first_idx = self.n_block - 1
        path = None
        for idx in range(first_idx, -1, -1):
            refine = getattr(scaler, f'refinenet_{idx}')
            if idx == first_idx:
                # nothing to merge with yet: only refine this scale
                path = refine([feats[idx]], None)
            else:
                # upsample the running result to this scale's spatial size
                # and merge it with this scale's feature
                path = refine([path, feats[idx]], feats[idx].size()[2:])
            refined_embeds.append(path)

        return refined_embeds
- def _make_multi_scale_layers(in_shape,
- out_shape,
- n_block=4,
- groups=1,
- use_bn=False):
- out_shapes = [out_shape for _ in range(n_block)]
- multi_scaler = nn.Module()
- for idx in range(n_block - 1, -1, -1):
- """
- 1 x 1 conv for dim reduction -> group norm
- """
- layer_name = f'layer_{(idx)}_rn'
- multi_scaler.add_module(
- layer_name,
- nn.Conv2d(in_shape[idx], out_shapes[idx], kernel_size=1))
- layer_name = f'p_norm_{(idx)}'
- multi_scaler.add_module(layer_name, nn.GroupNorm(32, out_shapes[idx]))
- layer_name = f'refinenet_{idx}'
- multi_scaler.add_module(layer_name,
- _make_fusion_block(out_shape, use_bn))
- # initialize for the 1x1 conv
- nn.init.xavier_uniform_(
- getattr(multi_scaler, f'layer_{idx}_rn').weight, gain=1)
- nn.init.constant_(getattr(multi_scaler, f'layer_{idx}_rn').bias, 0)
- return multi_scaler
def _make_fusion_block(features, use_bn):
    """Create one fusion block (resnet-bottleneck style) for the fpn.

    Args:
        features: channel dim handled by the block
        use_bn: whether the block's residual units use bn

    Returns:
        a FeatureFusionBlock with a non-inplace ReLU, no channel expansion
        and align_corners enabled
    """
    relu = nn.ReLU(inplace=False)
    block = FeatureFusionBlock(features,
                               relu,
                               bn=use_bn,
                               expand=False,
                               align_corners=True)
    return block
class FeatureFusionBlock(nn.Module):
    """Fuses an already-refined feature with the feature of the next scale."""

    def __init__(self,
                 features,
                 activation,
                 bn=False,
                 expand=False,
                 align_corners=True):
        """Set up the smoothing conv and the two residual units.

        Args:
            features (int): channel dim of the input feature
            activation: activation function handed to the residual units
            bn: whether the residual units use bn
            expand: when exactly True, the smoothing conv outputs
                features // 2 channels instead of features
            align_corners: align_corners flag used for the interpolation
        """
        super().__init__()

        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        smoothed_dim = features // 2 if self.expand is True else features

        self.smoothing = nn.Conv2d(features,
                                   smoothed_dim,
                                   kernel_size=1,
                                   bias=True,
                                   groups=1)
        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, xs, up_size):
        """Run one fusion step.

        Args:
            xs: xs[0] is the feature refined in the previous step; when a
                second entry is present, xs[1] is the next scale feature to
                fuse with it
            up_size: target size for upsampling xs[0] before it is merged
                with xs[1]; not used when xs holds a single entry

        Returns:
            the fused feature, which is fed to the next fusion step
        """
        fused = xs[0]

        # single input: nothing to merge, only the post refinement applies
        if len(xs) != 2:
            return self.resConfUnit2(fused)

        # upsampling
        upsampled = nn.functional.interpolate(fused,
                                              size=up_size,
                                              mode='bilinear',
                                              align_corners=self.align_corners)
        # feature smoothing since the upsampled feature is coarse-grain
        smoothed = self.smoothing(upsampled)
        # refine the next scale feature before fusion
        lateral = self.resConfUnit1(xs[1])
        # fusion, then post refine
        merged = self.skip_add.add(smoothed, lateral)
        return self.resConfUnit2(merged)
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Pre-activation bottleneck: activation -> 1x1 conv -> activation ->
    3x3 conv -> activation -> 1x1 conv, with the block input added back to
    the result. Each conv is optionally followed by batch norm.
    """

    def __init__(self, features, activation, bn, hidden_dim=64):
        """Init.

        Args:
            features (int): channel dim of the input (and of the output)
            activation: activation function
            bn: whether to use bn; batch norm layers are only created when
                this is exactly True, and the convs drop their bias whenever
                it is truthy
            hidden_dim (int): channel dim inside the bottleneck (default=64)
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        use_bias = not self.bn

        self.conv1 = nn.Conv2d(
            features,
            hidden_dim,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            groups=self.groups,
        )
        self.conv2 = nn.Conv2d(
            hidden_dim,
            hidden_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            groups=self.groups,
        )
        self.conv3 = nn.Conv2d(
            hidden_dim,
            features,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            groups=self.groups,
        )

        if self.bn is True:
            # Each norm must match the channel count of the conv it follows:
            # conv1 / conv2 emit hidden_dim channels, conv3 emits features.
            self.bn1 = nn.BatchNorm2d(hidden_dim)
            self.bn2 = nn.BatchNorm2d(hidden_dim)
            self.bn3 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input feature

        Returns:
            tensor: output feature (bottleneck output plus the input x)
        """
        use_norm = self.bn is True

        out = self.activation(x)
        out = self.conv1(out)
        if use_norm:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if use_norm:
            out = self.bn2(out)

        out = self.activation(out)
        out = self.conv3(out)
        if use_norm:
            out = self.bn3(out)

        # NOTE: groups is fixed to 1 in __init__, so no merge conv is needed
        # after the grouped convs; the former `groups > 1` branch referenced
        # a layer that was never created and could not run.
        return self.skip_add.add(out, x)
|