# resnet.py
#
# Copyright 2022 OFA-Sys Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. def drop_path(x, drop_prob: float = 0., training: bool = False):
  17. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  18. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  19. the original name is misleading as 'Drop Connect' is a.sh different form of dropout in a.sh separate paper...
  20. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  21. changing the layer and argument names to 'drop path' rather than mix DropConnect as a.sh layer name and use
  22. 'survival rate' as the argument.
  23. """
  24. if drop_prob == 0. or not training:
  25. return x
  26. keep_prob = 1 - drop_prob
  27. shape = (x.shape[0], ) + (1, ) * (
  28. x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  29. random_tensor = keep_prob + torch.rand(
  30. shape, dtype=x.dtype, device=x.device)
  31. random_tensor.floor_() # binarize
  32. output = x.div(keep_prob) * random_tensor
  33. return output
  34. class DropPath(nn.Module):
  35. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  36. """
  37. def __init__(self, drop_prob=None):
  38. super(DropPath, self).__init__()
  39. self.drop_prob = drop_prob
  40. def forward(self, x):
  41. return drop_path(x, self.drop_prob, self.training)
  42. def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  43. """3x3 convolution with padding"""
  44. return nn.Conv2d(
  45. in_planes,
  46. out_planes,
  47. kernel_size=3,
  48. stride=stride,
  49. padding=dilation,
  50. groups=groups,
  51. bias=False,
  52. dilation=dilation)
  53. def conv1x1(in_planes, out_planes, stride=1):
  54. """1x1 convolution"""
  55. return nn.Conv2d(
  56. in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  57. class BasicBlock(nn.Module):
  58. expansion = 1
  59. def __init__(self,
  60. inplanes,
  61. planes,
  62. stride=1,
  63. downsample=None,
  64. groups=1,
  65. base_width=64,
  66. dilation=1,
  67. norm_layer=None):
  68. super(BasicBlock, self).__init__()
  69. if norm_layer is None:
  70. norm_layer = nn.BatchNorm2d
  71. if groups != 1 or base_width != 64:
  72. raise ValueError(
  73. 'BasicBlock only supports groups=1 and base_width=64')
  74. if dilation > 1:
  75. raise NotImplementedError(
  76. 'Dilation > 1 not supported in BasicBlock')
  77. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  78. self.conv1 = conv3x3(inplanes, planes, stride)
  79. self.bn1 = norm_layer(planes)
  80. self.relu = nn.ReLU(inplace=True)
  81. self.conv2 = conv3x3(planes, planes)
  82. self.bn2 = norm_layer(planes)
  83. self.downsample = downsample
  84. self.stride = stride
  85. def forward(self, x):
  86. assert False
  87. identity = x
  88. out = self.conv1(x)
  89. out = self.bn1(out)
  90. out = self.relu(out)
  91. out = self.conv2(out)
  92. out = self.bn2(out)
  93. if self.downsample is not None:
  94. identity = self.downsample(x)
  95. out += identity
  96. out = self.relu(out)
  97. return out
  98. class Bottleneck(nn.Module):
  99. # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
  100. # while original implementation places the stride at the first 1x1 convolution(self.conv1)
  101. # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
  102. # This variant is also known as ResNet V1.5 and improves accuracy according to
  103. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
  104. expansion = 4
  105. def __init__(self,
  106. inplanes,
  107. planes,
  108. stride=1,
  109. downsample=None,
  110. groups=1,
  111. base_width=64,
  112. dilation=1,
  113. norm_layer=None,
  114. drop_path_rate=0.0):
  115. super(Bottleneck, self).__init__()
  116. if norm_layer is None:
  117. norm_layer = nn.BatchNorm2d
  118. width = int(planes * (base_width / 64.)) * groups
  119. # Both self.conv2 and self.downsample layers downsample the input when stride != 1
  120. self.conv1 = conv1x1(inplanes, width)
  121. self.bn1 = norm_layer(width)
  122. self.conv2 = conv3x3(width, width, stride, groups, dilation)
  123. self.bn2 = norm_layer(width)
  124. self.conv3 = conv1x1(width, planes * self.expansion)
  125. self.bn3 = norm_layer(planes * self.expansion)
  126. self.relu = nn.ReLU(inplace=True)
  127. self.downsample = downsample
  128. self.stride = stride
  129. self.drop_path = DropPath(
  130. drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  131. def forward(self, x):
  132. identity = x
  133. out = self.conv1(x)
  134. out = self.bn1(out)
  135. out = self.relu(out)
  136. out = self.conv2(out)
  137. out = self.bn2(out)
  138. out = self.relu(out)
  139. out = self.conv3(out)
  140. out = self.bn3(out)
  141. if self.downsample is not None:
  142. identity = self.downsample(x)
  143. out = identity + self.drop_path(out)
  144. out = self.relu(out)
  145. return out
  146. class ResNet(nn.Module):
  147. r"""
  148. Deep residual network, copy from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py.
  149. You can see more details from https://arxiv.org/abs/1512.03385
  150. step 1. Get image embedding with `7` as the patch image size, `2` as stride.
  151. step 2. Do layer normalization, relu activation and max pooling.
  152. step 3. Go through three times residual branch.
  153. """
  154. def __init__(self,
  155. layers,
  156. zero_init_residual=False,
  157. groups=1,
  158. width_per_group=64,
  159. replace_stride_with_dilation=None,
  160. norm_layer=None,
  161. drop_path_rate=0.0):
  162. r"""
  163. Args:
  164. layers (`Tuple[int]`): There are three layers in resnet, so the length
  165. of layers should greater then three. And each element in `layers` is
  166. the number of `Bottleneck` in relative residual branch.
  167. zero_init_residual (`bool`, **optional**, default to `False`):
  168. Whether or not to zero-initialize the last BN in each residual branch.
  169. groups (`int`, **optional**, default to `1`):
  170. The number of groups. So far, only the value of `1` is supported.
  171. width_per_group (`int`, **optional**, default to `64`):
  172. The width in each group. So far, only the value of `64` is supported.
  173. replace_stride_with_dilation (`Tuple[bool]`, **optional**, default to `None`):
  174. Whether or not to replace stride with dilation in each residual branch.
  175. norm_layer (`torch.nn.Module`, **optional**, default to `None`):
  176. The normalization module. If `None`, will use `torch.nn.BatchNorm2d`.
  177. drop_path_rate (`float`, **optional**, default to 0.0):
  178. Drop path rate. See more details about drop path from
  179. https://arxiv.org/pdf/1605.07648v4.pdf.
  180. """
  181. super(ResNet, self).__init__()
  182. if norm_layer is None:
  183. norm_layer = nn.BatchNorm2d
  184. self._norm_layer = norm_layer
  185. self.inplanes = 64
  186. self.dilation = 1
  187. if replace_stride_with_dilation is None:
  188. # each element in the tuple indicates if we should replace
  189. # the 2x2 stride with a dilated convolution instead
  190. replace_stride_with_dilation = [False, False, False]
  191. if len(replace_stride_with_dilation) != 3:
  192. raise ValueError('replace_stride_with_dilation should be None '
  193. 'or a 3-element tuple, got {}'.format(
  194. replace_stride_with_dilation))
  195. self.groups = groups
  196. self.base_width = width_per_group
  197. self.conv1 = nn.Conv2d(
  198. 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
  199. self.bn1 = norm_layer(self.inplanes)
  200. self.relu = nn.ReLU(inplace=True)
  201. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  202. self.layer1 = self._make_layer(
  203. Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
  204. self.layer2 = self._make_layer(
  205. Bottleneck,
  206. 128,
  207. layers[1],
  208. stride=2,
  209. dilate=replace_stride_with_dilation[0],
  210. drop_path_rate=drop_path_rate)
  211. self.layer3 = self._make_layer(
  212. Bottleneck,
  213. 256,
  214. layers[2],
  215. stride=2,
  216. dilate=replace_stride_with_dilation[1],
  217. drop_path_rate=drop_path_rate)
  218. for m in self.modules():
  219. if isinstance(m, nn.Conv2d):
  220. nn.init.kaiming_normal_(
  221. m.weight, mode='fan_out', nonlinearity='relu')
  222. elif isinstance(m,
  223. (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
  224. nn.init.constant_(m.weight, 1)
  225. nn.init.constant_(m.bias, 0)
  226. # Zero-initialize the last BN in each residual branch,
  227. # so that the residual branch starts with zeros, and each residual block behaves like an identity.
  228. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  229. if zero_init_residual:
  230. for m in self.modules():
  231. if isinstance(m, Bottleneck):
  232. nn.init.constant_(m.bn3.weight, 0)
  233. elif isinstance(m, BasicBlock):
  234. nn.init.constant_(m.bn2.weight, 0)
  235. def _make_layer(self,
  236. block,
  237. planes,
  238. blocks,
  239. stride=1,
  240. dilate=False,
  241. drop_path_rate=0.0):
  242. r"""
  243. Making a single residual branch.
  244. step 1. If dilate==`True`, switch the value of dilate and stride.
  245. step 2. If the input dimension doesn't equal to th output output dimension
  246. in `block`, initialize a down sample module.
  247. step 3. Build a sequential of `blocks` number of `block`.
  248. Args:
  249. block (`torch.nn.Module`): The basic block in residual branch.
  250. planes (`int`): The output dimension of each basic block.
  251. blocks (`int`): The number of `block` in residual branch.
  252. stride (`int`, **optional**, default to `1`):
  253. The stride using in conv.
  254. dilate (`bool`, **optional**, default to `False`):
  255. Whether or not to replace dilate with stride.
  256. drop_path_rate (`float`, **optional**, default to 0.0):
  257. Drop path rate. See more details about drop path from
  258. https://arxiv.org/pdf/1605.07648v4.pdf.
  259. Returns:
  260. A sequential of basic layer with type `torch.nn.Sequential[block]`
  261. """
  262. norm_layer = self._norm_layer
  263. downsample = None
  264. previous_dilation = self.dilation
  265. if dilate:
  266. self.dilation *= stride
  267. stride = 1
  268. if stride != 1 or self.inplanes != planes * block.expansion:
  269. downsample = nn.Sequential(
  270. conv1x1(self.inplanes, planes * block.expansion, stride),
  271. norm_layer(planes * block.expansion),
  272. )
  273. layers = []
  274. layers.append(
  275. block(self.inplanes, planes, stride, downsample, self.groups,
  276. self.base_width, previous_dilation, norm_layer))
  277. self.inplanes = planes * block.expansion
  278. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)]
  279. for i in range(1, blocks):
  280. layers.append(
  281. block(
  282. self.inplanes,
  283. planes,
  284. groups=self.groups,
  285. base_width=self.base_width,
  286. dilation=self.dilation,
  287. norm_layer=norm_layer,
  288. drop_path_rate=dpr[i]))
  289. return nn.Sequential(*layers)
  290. def _forward_impl(self, x):
  291. x = self.conv1(x)
  292. x = self.bn1(x)
  293. x = self.relu(x)
  294. x = self.maxpool(x)
  295. x = self.layer1(x)
  296. x = self.layer2(x)
  297. x = self.layer3(x)
  298. return x
  299. def forward(self, x):
  300. return self._forward_impl(x)