splat.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
# The implementation is adapted from Split-Attention Network, A New ResNet Variant,
# made publicly available under the Apache License 2.0
# at https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/models/splat.py
  4. """Split-Attention"""
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import nn
  8. from torch.nn import BatchNorm2d, Conv2d, Linear, Module, ReLU
  9. from torch.nn.modules.utils import _pair
  10. __all__ = ['SplAtConv2d']
  11. class SplAtConv2d(Module):
  12. """Split-Attention Conv2d
  13. """
  14. def __init__(self,
  15. in_channels,
  16. channels,
  17. kernel_size,
  18. stride=(1, 1),
  19. padding=(0, 0),
  20. dilation=(1, 1),
  21. groups=1,
  22. bias=True,
  23. radix=2,
  24. reduction_factor=4,
  25. rectify=False,
  26. rectify_avg=False,
  27. norm_layer=None,
  28. dropblock_prob=0.0,
  29. **kwargs):
  30. super(SplAtConv2d, self).__init__()
  31. padding = _pair(padding)
  32. self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
  33. self.rectify_avg = rectify_avg
  34. inter_channels = max(in_channels * radix // reduction_factor, 32)
  35. self.radix = radix
  36. self.cardinality = groups
  37. self.channels = channels
  38. self.dropblock_prob = dropblock_prob
  39. if self.rectify:
  40. self.conv = Conv2d(
  41. in_channels,
  42. channels * radix,
  43. kernel_size,
  44. stride,
  45. padding,
  46. dilation,
  47. groups=groups * radix,
  48. bias=bias,
  49. **kwargs)
  50. else:
  51. self.conv = Conv2d(
  52. in_channels,
  53. channels * radix,
  54. kernel_size,
  55. stride,
  56. padding,
  57. dilation,
  58. groups=groups * radix,
  59. bias=bias,
  60. **kwargs)
  61. self.use_bn = norm_layer is not None
  62. if self.use_bn:
  63. self.bn0 = norm_layer(channels * radix)
  64. self.relu = ReLU(inplace=True)
  65. self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
  66. if self.use_bn:
  67. self.bn1 = norm_layer(inter_channels)
  68. self.fc2 = Conv2d(
  69. inter_channels, channels * radix, 1, groups=self.cardinality)
  70. if dropblock_prob > 0.0:
  71. self.dropblock = DropBlock2D(dropblock_prob, 3)
  72. self.rsoftmax = rSoftMax(radix, groups)
  73. def forward(self, x):
  74. x = self.conv(x)
  75. if self.use_bn:
  76. x = self.bn0(x)
  77. if self.dropblock_prob > 0.0:
  78. x = self.dropblock(x)
  79. x = self.relu(x)
  80. batch, rchannel = x.shape[:2]
  81. if self.radix > 1:
  82. splited = torch.split(x, rchannel // self.radix, dim=1)
  83. gap = sum(splited)
  84. else:
  85. gap = x
  86. gap = F.adaptive_avg_pool2d(gap, 1)
  87. gap = self.fc1(gap)
  88. if self.use_bn:
  89. gap = self.bn1(gap)
  90. gap = self.relu(gap)
  91. atten = self.fc2(gap)
  92. atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
  93. if self.radix > 1:
  94. attens = torch.split(atten, rchannel // self.radix, dim=1)
  95. out = sum([att * split for (att, split) in zip(attens, splited)])
  96. else:
  97. out = atten * x
  98. return out.contiguous()
  99. class rSoftMax(nn.Module):
  100. def __init__(self, radix, cardinality):
  101. super().__init__()
  102. self.radix = radix
  103. self.cardinality = cardinality
  104. def forward(self, x):
  105. batch = x.size(0)
  106. if self.radix > 1:
  107. x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
  108. x = F.softmax(x, dim=1)
  109. x = x.reshape(batch, -1)
  110. else:
  111. x = torch.sigmoid(x)
  112. return x