resnet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
# The implementation is adapted from Split-Attention Network, A New ResNet Variant,
# made publicly available under the Apache License 2.0
  3. # at https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/models/resnet.py
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. from .splat import SplAtConv2d
  8. __all__ = ['ResNet', 'Bottleneck']
  9. class DropBlock2D(object):
  10. def __init__(self, *args, **kwargs):
  11. raise NotImplementedError
  12. class GlobalAvgPool2d(nn.Module):
  13. def __init__(self):
  14. """Global average pooling over the input's spatial dimensions"""
  15. super(GlobalAvgPool2d, self).__init__()
  16. def forward(self, inputs):
  17. return nn.functional.adaptive_avg_pool2d(inputs,
  18. 1).view(inputs.size(0), -1)
  19. class Bottleneck(nn.Module):
  20. expansion = 4
  21. def __init__(self,
  22. inplanes,
  23. planes,
  24. stride=1,
  25. downsample=None,
  26. radix=1,
  27. cardinality=1,
  28. bottleneck_width=64,
  29. avd=False,
  30. avd_first=False,
  31. dilation=1,
  32. is_first=False,
  33. rectified_conv=False,
  34. rectify_avg=False,
  35. norm_layer=None,
  36. dropblock_prob=0.0,
  37. last_gamma=False):
  38. super(Bottleneck, self).__init__()
  39. group_width = int(planes * (bottleneck_width / 64.)) * cardinality
  40. self.conv1 = nn.Conv2d(
  41. inplanes, group_width, kernel_size=1, bias=False)
  42. self.bn1 = norm_layer(group_width)
  43. self.dropblock_prob = dropblock_prob
  44. self.radix = radix
  45. self.avd = avd and (stride > 1 or is_first)
  46. self.avd_first = avd_first
  47. if self.avd:
  48. self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
  49. stride = 1
  50. if dropblock_prob > 0.0:
  51. self.dropblock1 = DropBlock2D(dropblock_prob, 3)
  52. if radix == 1:
  53. self.dropblock2 = DropBlock2D(dropblock_prob, 3)
  54. self.dropblock3 = DropBlock2D(dropblock_prob, 3)
  55. if radix >= 1:
  56. self.conv2 = SplAtConv2d(
  57. group_width,
  58. group_width,
  59. kernel_size=3,
  60. stride=stride,
  61. padding=dilation,
  62. dilation=dilation,
  63. groups=cardinality,
  64. bias=False,
  65. radix=radix,
  66. rectify=rectified_conv,
  67. rectify_avg=rectify_avg,
  68. norm_layer=norm_layer,
  69. dropblock_prob=dropblock_prob)
  70. elif rectified_conv:
  71. self.conv2 = nn.Conv2d(
  72. group_width,
  73. group_width,
  74. kernel_size=3,
  75. stride=stride,
  76. padding=dilation,
  77. dilation=dilation,
  78. groups=cardinality,
  79. bias=False)
  80. self.bn2 = norm_layer(group_width)
  81. else:
  82. self.conv2 = nn.Conv2d(
  83. group_width,
  84. group_width,
  85. kernel_size=3,
  86. stride=stride,
  87. padding=dilation,
  88. dilation=dilation,
  89. groups=cardinality,
  90. bias=False)
  91. self.bn2 = norm_layer(group_width)
  92. self.conv3 = nn.Conv2d(
  93. group_width, planes * 4, kernel_size=1, bias=False)
  94. self.bn3 = norm_layer(planes * 4)
  95. if last_gamma:
  96. from torch.nn.init import zeros_
  97. zeros_(self.bn3.weight)
  98. self.relu = nn.ReLU(inplace=True)
  99. self.downsample = downsample
  100. self.dilation = dilation
  101. self.stride = stride
  102. def forward(self, x):
  103. residual = x
  104. out = self.conv1(x)
  105. out = self.bn1(out)
  106. if self.dropblock_prob > 0.0:
  107. out = self.dropblock1(out)
  108. out = self.relu(out)
  109. if self.avd and self.avd_first:
  110. out = self.avd_layer(out)
  111. out = self.conv2(out)
  112. if self.radix == 0:
  113. out = self.bn2(out)
  114. if self.dropblock_prob > 0.0:
  115. out = self.dropblock2(out)
  116. out = self.relu(out)
  117. if self.avd and not self.avd_first:
  118. out = self.avd_layer(out)
  119. out = self.conv3(out)
  120. out = self.bn3(out)
  121. if self.dropblock_prob > 0.0:
  122. out = self.dropblock3(out)
  123. if self.downsample is not None:
  124. residual = self.downsample(x)
  125. out += residual
  126. out = self.relu(out)
  127. return out
  128. class ResNet(nn.Module):
  129. def __init__(self,
  130. block,
  131. layers,
  132. radix=1,
  133. groups=1,
  134. bottleneck_width=64,
  135. num_classes=1000,
  136. dilated=False,
  137. dilation=1,
  138. deep_stem=False,
  139. stem_width=64,
  140. avg_down=False,
  141. rectified_conv=False,
  142. rectify_avg=False,
  143. avd=False,
  144. avd_first=False,
  145. final_drop=0.0,
  146. dropblock_prob=0,
  147. last_gamma=False,
  148. norm_layer=nn.BatchNorm2d):
  149. self.cardinality = groups
  150. self.bottleneck_width = bottleneck_width
  151. # ResNet-D params
  152. self.inplanes = stem_width * 2 if deep_stem else 64
  153. self.avg_down = avg_down
  154. self.last_gamma = last_gamma
  155. # ResNeSt params
  156. self.radix = radix
  157. self.avd = avd
  158. self.avd_first = avd_first
  159. super(ResNet, self).__init__()
  160. self.rectified_conv = rectified_conv
  161. self.rectify_avg = rectify_avg
  162. if rectified_conv:
  163. conv_layer = nn.Conv2d
  164. else:
  165. conv_layer = nn.Conv2d
  166. conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
  167. if deep_stem:
  168. self.conv1 = nn.Sequential(
  169. conv_layer(
  170. 3,
  171. stem_width,
  172. kernel_size=3,
  173. stride=2,
  174. padding=1,
  175. bias=False,
  176. **conv_kwargs),
  177. norm_layer(stem_width),
  178. nn.ReLU(inplace=True),
  179. conv_layer(
  180. stem_width,
  181. stem_width,
  182. kernel_size=3,
  183. stride=1,
  184. padding=1,
  185. bias=False,
  186. **conv_kwargs),
  187. norm_layer(stem_width),
  188. nn.ReLU(inplace=True),
  189. conv_layer(
  190. stem_width,
  191. stem_width * 2,
  192. kernel_size=3,
  193. stride=1,
  194. padding=1,
  195. bias=False,
  196. **conv_kwargs),
  197. )
  198. else:
  199. self.conv1 = conv_layer(
  200. 3,
  201. 64,
  202. kernel_size=7,
  203. stride=2,
  204. padding=3,
  205. bias=False,
  206. **conv_kwargs)
  207. self.bn1 = norm_layer(self.inplanes)
  208. self.relu = nn.ReLU(inplace=True)
  209. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  210. self.layer1 = self._make_layer(
  211. block, 64, layers[0], norm_layer=norm_layer, is_first=False)
  212. self.layer2 = self._make_layer(
  213. block, 128, layers[1], stride=2, norm_layer=norm_layer)
  214. if dilated or dilation == 4:
  215. self.layer3 = self._make_layer(
  216. block,
  217. 256,
  218. layers[2],
  219. stride=1,
  220. dilation=2,
  221. norm_layer=norm_layer,
  222. dropblock_prob=dropblock_prob)
  223. self.layer4 = self._make_layer(
  224. block,
  225. 512,
  226. layers[3],
  227. stride=1,
  228. dilation=4,
  229. norm_layer=norm_layer,
  230. dropblock_prob=dropblock_prob)
  231. elif dilation == 2:
  232. self.layer3 = self._make_layer(
  233. block,
  234. 256,
  235. layers[2],
  236. stride=2,
  237. dilation=1,
  238. norm_layer=norm_layer,
  239. dropblock_prob=dropblock_prob)
  240. self.layer4 = self._make_layer(
  241. block,
  242. 512,
  243. layers[3],
  244. stride=1,
  245. dilation=2,
  246. norm_layer=norm_layer,
  247. dropblock_prob=dropblock_prob)
  248. else:
  249. self.layer3 = self._make_layer(
  250. block,
  251. 256,
  252. layers[2],
  253. stride=2,
  254. norm_layer=norm_layer,
  255. dropblock_prob=dropblock_prob)
  256. self.layer4 = self._make_layer(
  257. block,
  258. 512,
  259. layers[3],
  260. stride=2,
  261. norm_layer=norm_layer,
  262. dropblock_prob=dropblock_prob)
  263. self.avgpool = GlobalAvgPool2d()
  264. self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
  265. self.fc = nn.Linear(512 * block.expansion, num_classes)
  266. for m in self.modules():
  267. if isinstance(m, nn.Conv2d):
  268. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  269. m.weight.data.normal_(0, math.sqrt(2. / n))
  270. elif isinstance(m, norm_layer):
  271. m.weight.data.fill_(1)
  272. m.bias.data.zero_()
  273. def _make_layer(self,
  274. block,
  275. planes,
  276. blocks,
  277. stride=1,
  278. dilation=1,
  279. norm_layer=None,
  280. dropblock_prob=0.0,
  281. is_first=True):
  282. downsample = None
  283. if stride != 1 or self.inplanes != planes * block.expansion:
  284. down_layers = []
  285. if self.avg_down:
  286. if dilation == 1:
  287. down_layers.append(
  288. nn.AvgPool2d(
  289. kernel_size=stride,
  290. stride=stride,
  291. ceil_mode=True,
  292. count_include_pad=False))
  293. else:
  294. down_layers.append(
  295. nn.AvgPool2d(
  296. kernel_size=1,
  297. stride=1,
  298. ceil_mode=True,
  299. count_include_pad=False))
  300. down_layers.append(
  301. nn.Conv2d(
  302. self.inplanes,
  303. planes * block.expansion,
  304. kernel_size=1,
  305. stride=1,
  306. bias=False))
  307. else:
  308. down_layers.append(
  309. nn.Conv2d(
  310. self.inplanes,
  311. planes * block.expansion,
  312. kernel_size=1,
  313. stride=stride,
  314. bias=False))
  315. down_layers.append(norm_layer(planes * block.expansion))
  316. downsample = nn.Sequential(*down_layers)
  317. layers = []
  318. if dilation == 1 or dilation == 2:
  319. layers.append(
  320. block(
  321. self.inplanes,
  322. planes,
  323. stride,
  324. downsample=downsample,
  325. radix=self.radix,
  326. cardinality=self.cardinality,
  327. bottleneck_width=self.bottleneck_width,
  328. avd=self.avd,
  329. avd_first=self.avd_first,
  330. dilation=1,
  331. is_first=is_first,
  332. rectified_conv=self.rectified_conv,
  333. rectify_avg=self.rectify_avg,
  334. norm_layer=norm_layer,
  335. dropblock_prob=dropblock_prob,
  336. last_gamma=self.last_gamma))
  337. elif dilation == 4:
  338. layers.append(
  339. block(
  340. self.inplanes,
  341. planes,
  342. stride,
  343. downsample=downsample,
  344. radix=self.radix,
  345. cardinality=self.cardinality,
  346. bottleneck_width=self.bottleneck_width,
  347. avd=self.avd,
  348. avd_first=self.avd_first,
  349. dilation=2,
  350. is_first=is_first,
  351. rectified_conv=self.rectified_conv,
  352. rectify_avg=self.rectify_avg,
  353. norm_layer=norm_layer,
  354. dropblock_prob=dropblock_prob,
  355. last_gamma=self.last_gamma))
  356. else:
  357. raise RuntimeError('=> unknown dilation size: {}'.format(dilation))
  358. self.inplanes = planes * block.expansion
  359. for i in range(1, blocks):
  360. layers.append(
  361. block(
  362. self.inplanes,
  363. planes,
  364. radix=self.radix,
  365. cardinality=self.cardinality,
  366. bottleneck_width=self.bottleneck_width,
  367. avd=self.avd,
  368. avd_first=self.avd_first,
  369. dilation=dilation,
  370. rectified_conv=self.rectified_conv,
  371. rectify_avg=self.rectify_avg,
  372. norm_layer=norm_layer,
  373. dropblock_prob=dropblock_prob,
  374. last_gamma=self.last_gamma))
  375. return nn.Sequential(*layers)
  376. def forward(self, x):
  377. x = self.conv1(x)
  378. x = self.bn1(x)
  379. x = self.relu(x)
  380. x = self.maxpool(x)
  381. x = self.layer1(x)
  382. x = self.layer2(x)
  383. x = self.layer3(x)
  384. x = self.layer4(x)
  385. x = self.avgpool(x)
  386. x = torch.flatten(x, 1)
  387. if self.drop:
  388. x = self.drop(x)
  389. x = self.fc(x)
  390. return x