rec_resnet_45.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/FangShancheng/ABINet/tree/main/modules
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. from paddle import ParamAttr
  23. from paddle.nn.initializer import KaimingNormal
  24. import paddle.nn as nn
  25. import paddle.nn.functional as F
  26. import numpy as np
  27. import math
  28. __all__ = ["ResNet45"]
  29. def conv1x1(in_planes, out_planes, stride=1):
  30. return nn.Conv2D(
  31. in_planes,
  32. out_planes,
  33. kernel_size=1,
  34. stride=1,
  35. weight_attr=ParamAttr(initializer=KaimingNormal()),
  36. bias_attr=False,
  37. )
  38. def conv3x3(in_channel, out_channel, stride=1):
  39. return nn.Conv2D(
  40. in_channel,
  41. out_channel,
  42. kernel_size=3,
  43. stride=stride,
  44. padding=1,
  45. weight_attr=ParamAttr(initializer=KaimingNormal()),
  46. bias_attr=False,
  47. )
  48. class BasicBlock(nn.Layer):
  49. expansion = 1
  50. def __init__(self, in_channels, channels, stride=1, downsample=None):
  51. super().__init__()
  52. self.conv1 = conv1x1(in_channels, channels)
  53. self.bn1 = nn.BatchNorm2D(channels)
  54. self.relu = nn.ReLU()
  55. self.conv2 = conv3x3(channels, channels, stride)
  56. self.bn2 = nn.BatchNorm2D(channels)
  57. self.downsample = downsample
  58. self.stride = stride
  59. def forward(self, x):
  60. residual = x
  61. out = self.conv1(x)
  62. out = self.bn1(out)
  63. out = self.relu(out)
  64. out = self.conv2(out)
  65. out = self.bn2(out)
  66. if self.downsample is not None:
  67. residual = self.downsample(x)
  68. out += residual
  69. out = self.relu(out)
  70. return out
  71. class ResNet45(nn.Layer):
  72. def __init__(
  73. self,
  74. in_channels=3,
  75. block=BasicBlock,
  76. layers=[3, 4, 6, 6, 3],
  77. strides=[2, 1, 2, 1, 1],
  78. ):
  79. self.inplanes = 32
  80. super(ResNet45, self).__init__()
  81. self.conv1 = nn.Conv2D(
  82. in_channels,
  83. 32,
  84. kernel_size=3,
  85. stride=1,
  86. padding=1,
  87. weight_attr=ParamAttr(initializer=KaimingNormal()),
  88. bias_attr=False,
  89. )
  90. self.bn1 = nn.BatchNorm2D(32)
  91. self.relu = nn.ReLU()
  92. self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
  93. self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
  94. self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
  95. self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
  96. self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
  97. self.out_channels = 512
  98. def _make_layer(self, block, planes, blocks, stride=1):
  99. downsample = None
  100. if stride != 1 or self.inplanes != planes * block.expansion:
  101. # downsample = True
  102. downsample = nn.Sequential(
  103. nn.Conv2D(
  104. self.inplanes,
  105. planes * block.expansion,
  106. kernel_size=1,
  107. stride=stride,
  108. weight_attr=ParamAttr(initializer=KaimingNormal()),
  109. bias_attr=False,
  110. ),
  111. nn.BatchNorm2D(planes * block.expansion),
  112. )
  113. layers = []
  114. layers.append(block(self.inplanes, planes, stride, downsample))
  115. self.inplanes = planes * block.expansion
  116. for i in range(1, blocks):
  117. layers.append(block(self.inplanes, planes))
  118. return nn.Sequential(*layers)
  119. def forward(self, x):
  120. x = self.conv1(x)
  121. x = self.bn1(x)
  122. x = self.relu(x)
  123. x = self.layer1(x)
  124. x = self.layer2(x)
  125. x = self.layer3(x)
  126. x = self.layer4(x)
  127. x = self.layer5(x)
  128. return x