# fpn_fusion.py
# The implementation here is modified based on timm,
# originally Apache 2.0 License and publicly available at
# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/fpn_fusion.py
  4. import torch.nn as nn
  5. class FPNFusionModule(nn.Module):
  6. """ This is a fpn-style cross-scale feature fusion module" """
  7. def __init__(self, embed_dims, fuse_dim=256, n_block=4, use_bn=False):
  8. super().__init__()
  9. """ Initializes the model.
  10. Args:
  11. embed_dims: the list of channel dim for different scale feature maps (i.e., the input)
  12. fuse_dim: the channel dim of the fused feature map (i.e., the output)
  13. n_block: the number of multi-scale features (default=4)
  14. use_bn: whether to use bn
  15. """
  16. self.embed_dims = embed_dims
  17. self.fuse_dim = fuse_dim
  18. self.n_block = n_block
  19. # cross-scale fusion layers
  20. self.multi_scaler = _make_multi_scale_layers(
  21. embed_dims, fuse_dim, use_bn=use_bn, n_block=n_block)
  22. def forward(self, x_blocks):
  23. x_blocks = x_blocks
  24. # preparation: channel reduction and normalization
  25. for idx in range(self.n_block - 1, -1, -1):
  26. x_blocks[idx] = getattr(self.multi_scaler, f'layer_{idx}_rn')(
  27. x_blocks[idx])
  28. x_blocks[idx] = getattr(self.multi_scaler, f'p_norm_{idx}')(
  29. x_blocks[idx])
  30. # cross-scale fusion
  31. refined_embeds = []
  32. for idx in range(self.n_block - 1, -1, -1):
  33. if idx == self.n_block - 1:
  34. path = getattr(self.multi_scaler,
  35. f'refinenet_{idx}')([x_blocks[idx]], None)
  36. else:
  37. path = getattr(self.multi_scaler,
  38. f'refinenet_{idx}')([path, x_blocks[idx]],
  39. x_blocks[idx].size()[2:])
  40. refined_embeds.append(path)
  41. return refined_embeds
  42. def _make_multi_scale_layers(in_shape,
  43. out_shape,
  44. n_block=4,
  45. groups=1,
  46. use_bn=False):
  47. out_shapes = [out_shape for _ in range(n_block)]
  48. multi_scaler = nn.Module()
  49. for idx in range(n_block - 1, -1, -1):
  50. """
  51. 1 x 1 conv for dim reduction -> group norm
  52. """
  53. layer_name = f'layer_{(idx)}_rn'
  54. multi_scaler.add_module(
  55. layer_name,
  56. nn.Conv2d(in_shape[idx], out_shapes[idx], kernel_size=1))
  57. layer_name = f'p_norm_{(idx)}'
  58. multi_scaler.add_module(layer_name, nn.GroupNorm(32, out_shapes[idx]))
  59. layer_name = f'refinenet_{idx}'
  60. multi_scaler.add_module(layer_name,
  61. _make_fusion_block(out_shape, use_bn))
  62. # initialize for the 1x1 conv
  63. nn.init.xavier_uniform_(
  64. getattr(multi_scaler, f'layer_{idx}_rn').weight, gain=1)
  65. nn.init.constant_(getattr(multi_scaler, f'layer_{idx}_rn').bias, 0)
  66. return multi_scaler
  67. def _make_fusion_block(features, use_bn):
  68. """ We use a resnet bottleneck structure for fpn """
  69. return FeatureFusionBlock(
  70. features,
  71. nn.ReLU(False),
  72. bn=use_bn,
  73. expand=False,
  74. align_corners=True,
  75. )
  76. class FeatureFusionBlock(nn.Module):
  77. """ Feature fusion block """
  78. def __init__(self,
  79. features,
  80. activation,
  81. bn=False,
  82. expand=False,
  83. align_corners=True):
  84. """Init.
  85. Args:
  86. features (int): channel dim of the input feature
  87. activation: activation function to use
  88. bn: whether to use bn
  89. expand: whether to expand feature or not
  90. align_corners: whether to use align_corners for interpolation
  91. """
  92. super(FeatureFusionBlock, self).__init__()
  93. self.align_corners = align_corners
  94. self.groups = 1
  95. self.expand = expand
  96. out_features = features
  97. if self.expand is True:
  98. out_features = features // 2
  99. self.smoothing = nn.Conv2d(
  100. features,
  101. out_features,
  102. kernel_size=1,
  103. bias=True,
  104. groups=1,
  105. )
  106. self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
  107. self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
  108. self.skip_add = nn.quantized.FloatFunctional()
  109. def forward(self, xs, up_size):
  110. """ Forward pass.
  111. Args
  112. xs: xs[0]: the feature refined from the previous step, xs[1]: the next scale features to fuse
  113. up_size: the size for upsampling; xs[0] is upsampled before merging with xs[1]
  114. Returns:
  115. output: the fused feature, which is fed to the next fusion step as an input
  116. """
  117. output = xs[0]
  118. if len(xs) == 2:
  119. # upsampling
  120. output = nn.functional.interpolate(
  121. output,
  122. size=up_size,
  123. mode='bilinear',
  124. align_corners=self.align_corners)
  125. # feature smoothing since the upsampled feature is coarse-grain
  126. output = self.smoothing(output)
  127. # refine the next scale feature before fusion
  128. res = self.resConfUnit1(xs[1])
  129. # fusion
  130. output = self.skip_add.add(output, res)
  131. # post refine after fusion
  132. output = self.resConfUnit2(output)
  133. return output
  134. class ResidualConvUnit(nn.Module):
  135. """ Residual convolution module. """
  136. def __init__(self, features, activation, bn):
  137. """Init.
  138. Args:
  139. features (int): channel dim of the input
  140. activation: activation function
  141. bn: whether to use bn
  142. """
  143. super().__init__()
  144. self.bn = bn
  145. self.groups = 1
  146. self.conv1 = nn.Conv2d(
  147. features,
  148. 64,
  149. kernel_size=1,
  150. stride=1,
  151. bias=not self.bn,
  152. groups=self.groups,
  153. )
  154. self.conv2 = nn.Conv2d(
  155. 64,
  156. 64,
  157. kernel_size=3,
  158. stride=1,
  159. padding=1,
  160. bias=not self.bn,
  161. groups=self.groups,
  162. )
  163. self.conv3 = nn.Conv2d(
  164. 64,
  165. features,
  166. kernel_size=1,
  167. stride=1,
  168. bias=not self.bn,
  169. groups=self.groups,
  170. )
  171. if self.bn is True:
  172. self.bn1 = nn.BatchNorm2d(features)
  173. self.bn2 = nn.BatchNorm2d(features)
  174. self.bn3 = nn.BatchNorm2d(features)
  175. self.activation = activation
  176. self.skip_add = nn.quantized.FloatFunctional()
  177. def forward(self, x):
  178. """ Forward pass
  179. Args:
  180. x (tensor): input feature
  181. Returns:
  182. tensor: output feature
  183. """
  184. out = self.activation(x)
  185. out = self.conv1(out)
  186. if self.bn is True:
  187. out = self.bn1(out)
  188. out = self.activation(out)
  189. out = self.conv2(out)
  190. if self.bn is True:
  191. out = self.bn2(out)
  192. out = self.activation(out)
  193. out = self.conv3(out)
  194. if self.bn is True:
  195. out = self.bn3(out)
  196. if self.groups > 1:
  197. out = self.conv_merge(out)
  198. return self.skip_add.add(out, x)