resnet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import math
  2. import paddle
  3. from paddle import nn
# Local alias so the rest of this file can refer to Paddle's 2-D batch-norm
# layer under the name ``BatchNorm2d``.
BatchNorm2d = nn.BatchNorm2D

# Names exported by a star-import of this module.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "deformable_resnet18",
    "deformable_resnet50",
    "resnet152",
]

# Download URLs of the PyTorch ``.pth`` checkpoints, keyed by architecture name.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
  22. def constant_init(module, constant, bias=0):
  23. module.weight = paddle.create_parameter(
  24. shape=module.weight.shape,
  25. dtype="float32",
  26. default_initializer=paddle.nn.initializer.Constant(constant),
  27. )
  28. if hasattr(module, "bias"):
  29. module.bias = paddle.create_parameter(
  30. shape=module.bias.shape,
  31. dtype="float32",
  32. default_initializer=paddle.nn.initializer.Constant(bias),
  33. )
  34. def conv3x3(in_planes, out_planes, stride=1):
  35. """3x3 convolution with padding"""
  36. return nn.Conv2D(
  37. in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
  38. )
  39. class BasicBlock(nn.Layer):
  40. expansion = 1
  41. def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
  42. super(BasicBlock, self).__init__()
  43. self.with_dcn = dcn is not None
  44. self.conv1 = conv3x3(inplanes, planes, stride)
  45. self.bn1 = BatchNorm2d(planes, momentum=0.1)
  46. self.relu = nn.ReLU()
  47. self.with_modulated_dcn = False
  48. if not self.with_dcn:
  49. self.conv2 = nn.Conv2D(
  50. planes, planes, kernel_size=3, padding=1, bias_attr=False
  51. )
  52. else:
  53. from paddle.vision.ops import DeformConv2D
  54. deformable_groups = dcn.get("deformable_groups", 1)
  55. offset_channels = 18
  56. self.conv2_offset = nn.Conv2D(
  57. planes, deformable_groups * offset_channels, kernel_size=3, padding=1
  58. )
  59. self.conv2 = DeformConv2D(
  60. planes, planes, kernel_size=3, padding=1, bias_attr=False
  61. )
  62. self.bn2 = BatchNorm2d(planes, momentum=0.1)
  63. self.downsample = downsample
  64. self.stride = stride
  65. def forward(self, x):
  66. residual = x
  67. out = self.conv1(x)
  68. out = self.bn1(out)
  69. out = self.relu(out)
  70. # out = self.conv2(out)
  71. if not self.with_dcn:
  72. out = self.conv2(out)
  73. else:
  74. offset = self.conv2_offset(out)
  75. out = self.conv2(out, offset)
  76. out = self.bn2(out)
  77. if self.downsample is not None:
  78. residual = self.downsample(x)
  79. out += residual
  80. out = self.relu(out)
  81. return out
  82. class Bottleneck(nn.Layer):
  83. expansion = 4
  84. def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
  85. super(Bottleneck, self).__init__()
  86. self.with_dcn = dcn is not None
  87. self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
  88. self.bn1 = BatchNorm2d(planes, momentum=0.1)
  89. self.with_modulated_dcn = False
  90. if not self.with_dcn:
  91. self.conv2 = nn.Conv2D(
  92. planes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False
  93. )
  94. else:
  95. deformable_groups = dcn.get("deformable_groups", 1)
  96. from paddle.vision.ops import DeformConv2D
  97. offset_channels = 18
  98. self.conv2_offset = nn.Conv2D(
  99. planes,
  100. deformable_groups * offset_channels,
  101. stride=stride,
  102. kernel_size=3,
  103. padding=1,
  104. )
  105. self.conv2 = DeformConv2D(
  106. planes, planes, kernel_size=3, padding=1, stride=stride, bias_attr=False
  107. )
  108. self.bn2 = BatchNorm2d(planes, momentum=0.1)
  109. self.conv3 = nn.Conv2D(planes, planes * 4, kernel_size=1, bias_attr=False)
  110. self.bn3 = BatchNorm2d(planes * 4, momentum=0.1)
  111. self.relu = nn.ReLU()
  112. self.downsample = downsample
  113. self.stride = stride
  114. self.dcn = dcn
  115. self.with_dcn = dcn is not None
  116. def forward(self, x):
  117. residual = x
  118. out = self.conv1(x)
  119. out = self.bn1(out)
  120. out = self.relu(out)
  121. # out = self.conv2(out)
  122. if not self.with_dcn:
  123. out = self.conv2(out)
  124. else:
  125. offset = self.conv2_offset(out)
  126. out = self.conv2(out, offset)
  127. out = self.bn2(out)
  128. out = self.relu(out)
  129. out = self.conv3(out)
  130. out = self.bn3(out)
  131. if self.downsample is not None:
  132. residual = self.downsample(x)
  133. out += residual
  134. out = self.relu(out)
  135. return out
  136. class ResNet(nn.Layer):
  137. def __init__(self, block, layers, in_channels=3, dcn=None):
  138. self.dcn = dcn
  139. self.inplanes = 64
  140. super(ResNet, self).__init__()
  141. self.out_channels = []
  142. self.conv1 = nn.Conv2D(
  143. in_channels, 64, kernel_size=7, stride=2, padding=3, bias_attr=False
  144. )
  145. self.bn1 = BatchNorm2d(64, momentum=0.1)
  146. self.relu = nn.ReLU()
  147. self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
  148. self.layer1 = self._make_layer(block, 64, layers[0])
  149. self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn)
  150. self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn)
  151. self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn)
  152. if self.dcn is not None:
  153. for m in self.modules():
  154. if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
  155. if hasattr(m, "conv2_offset"):
  156. constant_init(m.conv2_offset, 0)
  157. def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
  158. downsample = None
  159. if stride != 1 or self.inplanes != planes * block.expansion:
  160. downsample = nn.Sequential(
  161. nn.Conv2D(
  162. self.inplanes,
  163. planes * block.expansion,
  164. kernel_size=1,
  165. stride=stride,
  166. bias_attr=False,
  167. ),
  168. BatchNorm2d(planes * block.expansion, momentum=0.1),
  169. )
  170. layers = []
  171. layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn))
  172. self.inplanes = planes * block.expansion
  173. for i in range(1, blocks):
  174. layers.append(block(self.inplanes, planes, dcn=dcn))
  175. self.out_channels.append(planes * block.expansion)
  176. return nn.Sequential(*layers)
  177. def forward(self, x):
  178. x = self.conv1(x)
  179. x = self.bn1(x)
  180. x = self.relu(x)
  181. x = self.maxpool(x)
  182. x2 = self.layer1(x)
  183. x3 = self.layer2(x2)
  184. x4 = self.layer3(x3)
  185. x5 = self.layer4(x4)
  186. return x2, x3, x4, x5
  187. def load_torch_params(paddle_model, torch_patams):
  188. paddle_params = paddle_model.state_dict()
  189. fc_names = ["classifier"]
  190. for key, torch_value in torch_patams.items():
  191. if "num_batches_tracked" in key:
  192. continue
  193. key = (
  194. key.replace("running_var", "_variance")
  195. .replace("running_mean", "_mean")
  196. .replace("module.", "")
  197. )
  198. torch_value = torch_value.detach().cpu().numpy()
  199. if key in paddle_params:
  200. flag = [i in key for i in fc_names]
  201. if any(flag) and "weight" in key: # ignore bias
  202. new_shape = [1, 0] + list(range(2, torch_value.ndim))
  203. print(
  204. f"name: {key}, ori shape: {torch_value.shape}, new shape: {torch_value.transpose(new_shape).shape}"
  205. )
  206. torch_value = torch_value.transpose(new_shape)
  207. paddle_params[key] = torch_value
  208. else:
  209. print(f"{key} not in paddle")
  210. paddle_model.set_state_dict(paddle_params)
  211. def load_models(model, model_name):
  212. import torch.utils.model_zoo as model_zoo
  213. torch_patams = model_zoo.load_url(model_urls[model_name])
  214. load_torch_params(model, torch_patams)
  215. def resnet18(pretrained=True, **kwargs):
  216. """Constructs a ResNet-18 model.
  217. Args:
  218. pretrained (bool): If True, returns a model pre-trained on ImageNet
  219. """
  220. model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
  221. if pretrained:
  222. assert (
  223. kwargs.get("in_channels", 3) == 3
  224. ), "in_channels must be 3 when pretrained is True"
  225. print("load from imagenet")
  226. load_models(model, "resnet18")
  227. return model
  228. def deformable_resnet18(pretrained=True, **kwargs):
  229. """Constructs a ResNet-18 model.
  230. Args:
  231. pretrained (bool): If True, returns a model pre-trained on ImageNet
  232. """
  233. model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs)
  234. if pretrained:
  235. assert (
  236. kwargs.get("in_channels", 3) == 3
  237. ), "in_channels must be 3 when pretrained is True"
  238. print("load from imagenet")
  239. model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]), strict=False)
  240. return model
  241. def resnet34(pretrained=True, **kwargs):
  242. """Constructs a ResNet-34 model.
  243. Args:
  244. pretrained (bool): If True, returns a model pre-trained on ImageNet
  245. """
  246. model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
  247. if pretrained:
  248. assert (
  249. kwargs.get("in_channels", 3) == 3
  250. ), "in_channels must be 3 when pretrained is True"
  251. model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]), strict=False)
  252. return model
  253. def resnet50(pretrained=True, **kwargs):
  254. """Constructs a ResNet-50 model.
  255. Args:
  256. pretrained (bool): If True, returns a model pre-trained on ImageNet
  257. """
  258. model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
  259. if pretrained:
  260. assert (
  261. kwargs.get("in_channels", 3) == 3
  262. ), "in_channels must be 3 when pretrained is True"
  263. load_models(model, "resnet50")
  264. return model
  265. def deformable_resnet50(pretrained=True, **kwargs):
  266. """Constructs a ResNet-50 model with deformable conv.
  267. Args:
  268. pretrained (bool): If True, returns a model pre-trained on ImageNet
  269. """
  270. model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs)
  271. if pretrained:
  272. assert (
  273. kwargs.get("in_channels", 3) == 3
  274. ), "in_channels must be 3 when pretrained is True"
  275. model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]), strict=False)
  276. return model
  277. def resnet101(pretrained=True, **kwargs):
  278. """Constructs a ResNet-101 model.
  279. Args:
  280. pretrained (bool): If True, returns a model pre-trained on ImageNet
  281. """
  282. model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
  283. if pretrained:
  284. assert (
  285. kwargs.get("in_channels", 3) == 3
  286. ), "in_channels must be 3 when pretrained is True"
  287. model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]), strict=False)
  288. return model
  289. def resnet152(pretrained=True, **kwargs):
  290. """Constructs a ResNet-152 model.
  291. Args:
  292. pretrained (bool): If True, returns a model pre-trained on ImageNet
  293. """
  294. model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
  295. if pretrained:
  296. assert (
  297. kwargs.get("in_channels", 3) == 3
  298. ), "in_channels must be 3 when pretrained is True"
  299. model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]), strict=False)
  300. return model
  301. if __name__ == "__main__":
  302. x = paddle.zeros([2, 3, 640, 640])
  303. net = resnet50(pretrained=True)
  304. y = net(x)
  305. for u in y:
  306. print(u.shape)
  307. print(net.out_channels)