rec_resnetv2.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnetv2.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import math
  22. import collections.abc
  23. from itertools import repeat
  24. from collections import OrderedDict # pylint: disable=g-importing-member
  25. import paddle
  26. import paddle.nn as nn
  27. import paddle.nn.functional as F
  28. from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingUniform
  29. from functools import partial
  30. from typing import Union, Callable, Type, List, Tuple
# Default per-channel normalisation statistics; `_cfg` exposes them as the
# "mean" and "std" entries of a model config.
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
# Shared initializer instances; elsewhere in this file they are applied by
# calling them on a parameter, e.g. `ones_(self.weight)`.
normal_ = Normal(mean=0.0, std=0.01)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
# NOTE(review): despite the name, this is a KaimingUniform initializer.
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
  37. def _ntuple(n):
  38. def parse(x):
  39. if isinstance(x, collections.abc.Iterable):
  40. return x
  41. return tuple(repeat(x, n))
  42. return parse
# Ready-made converters built from `_ntuple`: each returns an iterable input
# unchanged and repeats any other value the stated number of times.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Alias of the factory itself, for arbitrary lengths: `to_ntuple(n)(x)`.
to_ntuple = _ntuple
  48. class StdConv2dSame(nn.Conv2D):
  49. def __init__(
  50. self,
  51. in_channel,
  52. out_channels,
  53. kernel_size,
  54. stride=1,
  55. padding="SAME",
  56. dilation=1,
  57. groups=1,
  58. bias_attr=False,
  59. eps=1e-6,
  60. is_export=False,
  61. ):
  62. padding, is_dynamic = get_padding_value(
  63. padding, kernel_size, stride=stride, dilation=dilation
  64. )
  65. super().__init__(
  66. in_channel,
  67. out_channels,
  68. kernel_size,
  69. stride=stride,
  70. padding=padding,
  71. dilation=dilation,
  72. groups=groups,
  73. bias_attr=bias_attr,
  74. )
  75. self.same_pad = is_dynamic
  76. self.export = is_export
  77. self.eps = eps
  78. self.running_mean = paddle.zeros([self._out_channels], dtype="float32")
  79. self.running_variance = paddle.ones([self._out_channels], dtype="float32")
  80. self.batch_norm = paddle.nn.BatchNorm1D(
  81. self._out_channels, use_global_stats=False
  82. )
  83. def forward(self, x):
  84. if not self.training:
  85. self.export = True
  86. if self.same_pad:
  87. if self.export:
  88. x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
  89. else:
  90. x = pad_same(x, self._kernel_size, self._stride, self._dilation)
  91. if self.export:
  92. weight = paddle.reshape(
  93. self.batch_norm(
  94. self.weight.reshape([1, self._out_channels, -1]).cast(
  95. paddle.float32
  96. ),
  97. ),
  98. self.weight.shape,
  99. )
  100. else:
  101. weight = paddle.reshape(
  102. F.batch_norm(
  103. self.weight.reshape([1, self._out_channels, -1]),
  104. self.running_mean,
  105. self.running_variance,
  106. training=True,
  107. momentum=0.0,
  108. epsilon=self.eps,
  109. ),
  110. self.weight.shape,
  111. )
  112. x = F.conv2d(
  113. x,
  114. weight,
  115. self.bias,
  116. self._stride,
  117. self._padding,
  118. self._dilation,
  119. self._groups,
  120. )
  121. return x
  122. class StdConv2d(nn.Conv2D):
  123. """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
  124. Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
  125. https://arxiv.org/abs/1903.10520v2
  126. """
  127. def __init__(
  128. self,
  129. in_channel,
  130. out_channels,
  131. kernel_size,
  132. stride=1,
  133. padding=None,
  134. dilation=1,
  135. groups=1,
  136. bias=False,
  137. eps=1e-6,
  138. ):
  139. if padding is None:
  140. padding = get_padding(kernel_size, stride, dilation)
  141. super().__init__(
  142. in_channel,
  143. out_channels,
  144. kernel_size,
  145. stride=stride,
  146. padding=padding,
  147. dilation=dilation,
  148. groups=groups,
  149. bias_attr=bias,
  150. )
  151. self.eps = eps
  152. def forward(self, x):
  153. weight = F.batch_norm(
  154. self.weight.reshape(1, self.out_channels, -1),
  155. None,
  156. None,
  157. training=True,
  158. momentum=0.0,
  159. epsilon=self.eps,
  160. ).reshape_as(self.weight)
  161. x = F.conv2d(
  162. x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
  163. )
  164. return x
  165. class MaxPool2dSame(nn.MaxPool2D):
  166. """Tensorflow like 'SAME' wrapper for 2D max pooling"""
  167. def __init__(
  168. self,
  169. kernel_size: int,
  170. stride=None,
  171. padding=0,
  172. dilation=1,
  173. ceil_mode=False,
  174. is_export=False,
  175. ):
  176. kernel_size = to_2tuple(kernel_size)
  177. stride = to_2tuple(stride)
  178. dilation = to_2tuple(dilation)
  179. self.export = is_export
  180. super(MaxPool2dSame, self).__init__(
  181. kernel_size, stride, (0, 0), dilation, ceil_mode
  182. )
  183. def forward(self, x):
  184. if not self.training:
  185. self.export = True
  186. if self.export:
  187. x = pad_same_export(x, self.ksize, self.stride, value=-float("inf"))
  188. else:
  189. x = pad_same(x, self.ksize, self.stride, value=-float("inf"))
  190. return F.max_pool2d(x, self.ksize, self.stride, (0, 0), self.ceil_mode)
  191. def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
  192. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  193. return padding
  194. def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
  195. return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
  196. def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
  197. dynamic = False
  198. if isinstance(padding, str):
  199. # for any string padding, the padding will be calculated for you, one of three ways
  200. padding = padding.lower()
  201. if padding == "same":
  202. # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
  203. if is_static_pad(kernel_size, **kwargs):
  204. # static case, no extra overhead
  205. padding = get_padding(kernel_size, **kwargs)
  206. else:
  207. # dynamic 'SAME' padding, has runtime/GPU memory overhead
  208. padding = 0
  209. dynamic = True
  210. elif padding == "valid":
  211. # 'VALID' padding, same as padding=0
  212. padding = 0
  213. else:
  214. # Default to PyTorch style 'same'-ish symmetric padding
  215. padding = get_padding(kernel_size, **kwargs)
  216. return padding, dynamic
  217. def create_pool2d(pool_type, kernel_size, stride=None, is_export=False, **kwargs):
  218. stride = stride or kernel_size
  219. padding = kwargs.pop("padding", "")
  220. padding, is_dynamic = get_padding_value(
  221. padding, kernel_size, stride=stride, **kwargs
  222. )
  223. if is_dynamic:
  224. if pool_type == "avg":
  225. return AvgPool2dSame(
  226. kernel_size, stride=stride, is_export=is_export, **kwargs
  227. )
  228. elif pool_type == "max":
  229. return MaxPool2dSame(
  230. kernel_size, stride=stride, is_export=is_export, **kwargs
  231. )
  232. else:
  233. assert False, f"Unsupported pool type {pool_type}"
  234. def get_same_padding(x, k, s, d):
  235. return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
  236. def get_same_padding_export(x, k, s, d):
  237. x = paddle.to_tensor(x)
  238. k = paddle.to_tensor(k)
  239. s = paddle.to_tensor(s)
  240. d = paddle.to_tensor(d)
  241. return paddle.max((paddle.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
  242. def pad_same_export(x, k, s, d=(1, 1), value=0):
  243. ih, iw = x.shape[-2:]
  244. pad_h, pad_w = get_same_padding_export(
  245. ih, k[0], s[0], d[0]
  246. ), get_same_padding_export(iw, k[1], s[1], d[1])
  247. pad_h = pad_h.cast(paddle.int32)
  248. pad_w = pad_w.cast(paddle.int32)
  249. pad_list = paddle.to_tensor(
  250. [
  251. (pad_w // 2),
  252. (pad_w - pad_w // 2).cast(paddle.int32),
  253. (pad_h // 2).cast(paddle.int32),
  254. (pad_h - pad_h // 2).cast(paddle.int32),
  255. ]
  256. )
  257. if pad_h > 0 or pad_w > 0:
  258. if len(pad_list.shape) == 2:
  259. pad_list = pad_list.squeeze(1)
  260. x = F.pad(x, pad_list.cast(paddle.int32), value=value)
  261. return x
  262. def pad_same(x, k, s, d=(1, 1), value=0, pad_h=None, pad_w=None):
  263. ih, iw = x.shape[-2:]
  264. pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
  265. iw, k[1], s[1], d[1]
  266. )
  267. if pad_h > 0 or pad_w > 0:
  268. x = F.pad(
  269. x,
  270. [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
  271. value=value,
  272. )
  273. return x
  274. class AvgPool2dSame(nn.AvgPool2D):
  275. """Tensorflow like 'SAME' wrapper for 2D average pooling"""
  276. def __init__(
  277. self,
  278. kernel_size: int,
  279. stride=None,
  280. padding=0,
  281. ceil_mode=False,
  282. count_include_pad=True,
  283. ):
  284. kernel_size = to_2tuple(kernel_size)
  285. stride = to_2tuple(stride)
  286. super(AvgPool2dSame, self).__init__(
  287. kernel_size, stride, (0, 0), ceil_mode, count_include_pad
  288. )
  289. def forward(self, x):
  290. x = pad_same(x, self.kernel_size, self.stride)
  291. return F.avg_pool2d(
  292. x,
  293. self.kernel_size,
  294. self.stride,
  295. self.padding,
  296. self.ceil_mode,
  297. self.count_include_pad,
  298. )
  299. def drop_path(
  300. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  301. ):
  302. if drop_prob == 0.0 or not training:
  303. return x
  304. keep_prob = 1 - drop_prob
  305. shape = (x.shape[0],) + (1,) * (
  306. x.ndim - 1
  307. ) # work with diff dim tensors, not just 2D ConvNets
  308. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  309. if keep_prob > 0.0 and scale_by_keep:
  310. random_tensor.div_(keep_prob)
  311. return x * random_tensor
  312. class DropPath(nn.Layer):
  313. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  314. def __init__(self, drop_prob=None, scale_by_keep=True):
  315. super(DropPath, self).__init__()
  316. self.drop_prob = drop_prob
  317. self.scale_by_keep = scale_by_keep
  318. def forward(self, x):
  319. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  320. def adaptive_pool_feat_mult(pool_type="avg"):
  321. if pool_type == "catavgmax":
  322. return 2
  323. else:
  324. return 1
  325. class SelectAdaptivePool2d(nn.Layer):
  326. """Selectable global pooling layer with dynamic input kernel size"""
  327. def __init__(self, output_size=1, pool_type="fast", flatten=False):
  328. super(SelectAdaptivePool2d, self).__init__()
  329. self.pool_type = (
  330. pool_type or ""
  331. ) # convert other falsy values to empty string for consistent TS typing
  332. self.flatten = nn.Flatten(1) if flatten else nn.Identity()
  333. if pool_type == "":
  334. self.pool = nn.Identity() # pass through
  335. def is_identity(self):
  336. return not self.pool_type
  337. def forward(self, x):
  338. x = self.pool(x)
  339. x = self.flatten(x)
  340. return x
  341. def feat_mult(self):
  342. return adaptive_pool_feat_mult(self.pool_type)
  343. def __repr__(self):
  344. return (
  345. self.__class__.__name__
  346. + " ("
  347. + "pool_type="
  348. + self.pool_type
  349. + ", flatten="
  350. + str(self.flatten)
  351. + ")"
  352. )
  353. def _create_pool(num_features, num_classes, pool_type="avg", use_conv=False):
  354. flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
  355. if not pool_type:
  356. assert (
  357. num_classes == 0 or use_conv
  358. ), "Pooling can only be disabled if classifier is also removed or conv classifier is used"
  359. flatten_in_pool = (
  360. False # disable flattening if pooling is pass-through (no pooling)
  361. )
  362. global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
  363. num_pooled_features = num_features * global_pool.feat_mult()
  364. return global_pool, num_pooled_features
  365. def _create_fc(num_features, num_classes, use_conv=False):
  366. if num_classes <= 0:
  367. fc = nn.Identity() # pass-through (no classifier)
  368. elif use_conv:
  369. fc = nn.Conv2D(num_features, num_classes, 1, bias_attr=True)
  370. else:
  371. fc = nn.Linear(num_features, num_classes, bias_attr=True)
  372. return fc
  373. class ClassifierHead(nn.Layer):
  374. """Classifier head w/ configurable global pooling and dropout."""
  375. def __init__(
  376. self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
  377. ):
  378. super(ClassifierHead, self).__init__()
  379. self.drop_rate = drop_rate
  380. self.global_pool, num_pooled_features = _create_pool(
  381. in_chs, num_classes, pool_type, use_conv=use_conv
  382. )
  383. self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
  384. self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
  385. def forward(self, x):
  386. x = self.global_pool(x)
  387. if self.drop_rate:
  388. x = F.dropout(x, p=float(self.drop_rate), training=self.training)
  389. x = self.fc(x)
  390. x = self.flatten(x)
  391. return x
  392. class EvoNormBatch2d(nn.Layer):
  393. def __init__(
  394. self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None
  395. ):
  396. super(EvoNormBatch2d, self).__init__()
  397. self.apply_act = apply_act # apply activation (non-linearity)
  398. self.momentum = momentum
  399. self.eps = eps
  400. self.weight = paddle.create_parameter(
  401. paddle.ones(num_features), dtype="float32"
  402. )
  403. self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32")
  404. self.v = (
  405. paddle.create_parameter(paddle.ones(num_features), dtype="float32")
  406. if apply_act
  407. else None
  408. )
  409. self.register_buffer("running_var", paddle.ones([num_features]))
  410. self.reset_parameters()
  411. def reset_parameters(self):
  412. ones_(self.weight)
  413. zeros_(self.bias)
  414. if self.apply_act:
  415. ones_(self.v)
  416. def forward(self, x):
  417. x_type = x.dtype
  418. if self.v is not None:
  419. running_var = self.running_var.view(1, -1, 1, 1)
  420. if self.training:
  421. var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
  422. n = x.numel() / x.shape[1]
  423. running_var = var.detach() * self.momentum * (
  424. n / (n - 1)
  425. ) + running_var * (1 - self.momentum)
  426. self.running_var.copy_(running_var.view(self.running_var.shape))
  427. else:
  428. var = running_var
  429. v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
  430. d = x * v + (
  431. x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps
  432. ).sqrt().to(dtype=x_type)
  433. d = d.max((var + self.eps).sqrt().to(dtype=x_type))
  434. x = x / d
  435. return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
  436. class EvoNormSample2d(nn.Layer):
  437. def __init__(
  438. self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None
  439. ):
  440. super(EvoNormSample2d, self).__init__()
  441. self.apply_act = apply_act
  442. self.groups = groups
  443. self.eps = eps
  444. self.weight = paddle.create_parameter(
  445. paddle.ones(num_features), dtype="float32"
  446. )
  447. self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32")
  448. self.v = (
  449. paddle.create_parameter(paddle.ones(num_features), dtype="float32")
  450. if apply_act
  451. else None
  452. )
  453. self.reset_parameters()
  454. def reset_parameters(self):
  455. ones_(self.weight)
  456. zeros_(self.bias)
  457. if self.apply_act:
  458. ones_(self.v)
  459. def forward(self, x):
  460. B, C, H, W = x.shape
  461. if self.v is not None:
  462. n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
  463. x = x.reshape(B, self.groups, -1)
  464. x = (
  465. n.reshape(B, self.groups, -1)
  466. / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
  467. )
  468. x = x.reshape(B, C, H, W)
  469. return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
  470. class GroupNormAct(nn.GroupNorm):
  471. # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
  472. def __init__(
  473. self,
  474. num_channels,
  475. num_groups=32,
  476. eps=1e-5,
  477. affine=True,
  478. apply_act=True,
  479. act_layer=nn.ReLU,
  480. drop_block=None,
  481. ):
  482. super(GroupNormAct, self).__init__(num_groups, num_channels, epsilon=eps)
  483. if affine:
  484. self.weight = paddle.create_parameter([num_channels], dtype="float32")
  485. self.bias = paddle.create_parameter([num_channels], dtype="float32")
  486. ones_(self.weight)
  487. zeros_(self.bias)
  488. if act_layer is not None and apply_act:
  489. act_args = {}
  490. self.act = act_layer(**act_args)
  491. else:
  492. self.act = nn.Identity()
  493. def forward(self, x):
  494. x = F.group_norm(
  495. x,
  496. num_groups=self._num_groups,
  497. epsilon=self._epsilon,
  498. weight=self.weight,
  499. bias=self.bias,
  500. )
  501. x = self.act(x)
  502. return x
  503. class BatchNormAct2d(nn.BatchNorm2D):
  504. def __init__(
  505. self,
  506. num_features,
  507. eps=1e-5,
  508. momentum=0.1,
  509. affine=True,
  510. track_running_stats=True,
  511. apply_act=True,
  512. act_layer=nn.ReLU,
  513. drop_block=None,
  514. ):
  515. super(BatchNormAct2d, self).__init__(
  516. num_features,
  517. epsilon=eps,
  518. momentum=momentum,
  519. use_global_stats=track_running_stats,
  520. )
  521. if act_layer is not None and apply_act:
  522. act_args = dict()
  523. self.act = act_layer(**act_args)
  524. else:
  525. self.act = nn.Identity()
  526. def _forward_python(self, x):
  527. return super(BatchNormAct2d, self).forward(x)
  528. def forward(self, x):
  529. x = self._forward_python(x)
  530. x = self.act(x)
  531. return x
  532. def adapt_input_conv(in_chans, conv_weight):
  533. conv_type = conv_weight.dtype
  534. conv_weight = (
  535. conv_weight.float()
  536. ) # Some weights are in torch.half, ensure it's float for sum on CPU
  537. O, I, J, K = conv_weight.shape
  538. if in_chans == 1:
  539. if I > 3:
  540. assert conv_weight.shape[1] % 3 == 0
  541. # For models with space2depth stems
  542. conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
  543. conv_weight = conv_weight.sum(dim=2, keepdim=False)
  544. else:
  545. conv_weight = conv_weight.sum(dim=1, keepdim=True)
  546. elif in_chans != 3:
  547. if I != 3:
  548. raise NotImplementedError("Weight format not supported by conversion.")
  549. else:
  550. # NOTE this strategy should be better than random init, but there could be other combinations of
  551. # the original RGB input layer weights that'd work better for specific cases.
  552. repeat = int(math.ceil(in_chans / 3))
  553. conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
  554. conv_weight *= 3 / float(in_chans)
  555. conv_weight = conv_weight.to(conv_type)
  556. return conv_weight
  557. def named_apply(
  558. fn: Callable, module: nn.Layer, name="", depth_first=True, include_root=False
  559. ) -> nn.Layer:
  560. if not depth_first and include_root:
  561. fn(module=module, name=name)
  562. for child_name, child_module in module.named_children():
  563. child_name = ".".join((name, child_name)) if name else child_name
  564. named_apply(
  565. fn=fn,
  566. module=child_module,
  567. name=child_name,
  568. depth_first=depth_first,
  569. include_root=True,
  570. )
  571. if depth_first and include_root:
  572. fn(module=module, name=name)
  573. return module
  574. def _cfg(url="", **kwargs):
  575. return {
  576. "url": url,
  577. "num_classes": 1000,
  578. "input_size": (3, 224, 224),
  579. "pool_size": (7, 7),
  580. "crop_pct": 0.875,
  581. "interpolation": "bilinear",
  582. "mean": IMAGENET_INCEPTION_MEAN,
  583. "std": IMAGENET_INCEPTION_STD,
  584. "first_conv": "stem.conv",
  585. "classifier": "head.fc",
  586. **kwargs,
  587. }
  588. def make_div(v, divisor=8):
  589. min_value = divisor
  590. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  591. if new_v < 0.9 * v:
  592. new_v += divisor
  593. return new_v
  594. class PreActBottleneck(nn.Layer):
  595. """Pre-activation (v2) bottleneck block.
  596. Follows the implementation of "Identity Mappings in Deep Residual Networks":
  597. https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
  598. Except it puts the stride on 3x3 conv when available.
  599. """
  600. def __init__(
  601. self,
  602. in_chs,
  603. out_chs=None,
  604. bottle_ratio=0.25,
  605. stride=1,
  606. dilation=1,
  607. first_dilation=None,
  608. groups=1,
  609. act_layer=None,
  610. conv_layer=None,
  611. norm_layer=None,
  612. proj_layer=None,
  613. drop_path_rate=0.0,
  614. is_export=False,
  615. ):
  616. super().__init__()
  617. first_dilation = first_dilation or dilation
  618. conv_layer = conv_layer or StdConv2d
  619. norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
  620. out_chs = out_chs or in_chs
  621. mid_chs = make_div(out_chs * bottle_ratio)
  622. if proj_layer is not None:
  623. self.downsample = proj_layer(
  624. in_chs,
  625. out_chs,
  626. stride=stride,
  627. dilation=dilation,
  628. first_dilation=first_dilation,
  629. preact=True,
  630. conv_layer=conv_layer,
  631. norm_layer=norm_layer,
  632. )
  633. else:
  634. self.downsample = None
  635. self.norm1 = norm_layer(in_chs)
  636. self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
  637. self.norm2 = norm_layer(mid_chs)
  638. self.conv2 = conv_layer(
  639. mid_chs,
  640. mid_chs,
  641. 3,
  642. stride=stride,
  643. dilation=first_dilation,
  644. groups=groups,
  645. is_export=is_export,
  646. )
  647. self.norm3 = norm_layer(mid_chs)
  648. self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
  649. self.drop_path = (
  650. DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  651. )
  652. def zero_init_last(self):
  653. zeros_(self.conv3.weight)
  654. def forward(self, x):
  655. x_preact = self.norm1(x)
  656. # shortcut branch
  657. shortcut = x
  658. if self.downsample is not None:
  659. shortcut = self.downsample(x_preact)
  660. # residual branch
  661. x = self.conv1(x_preact)
  662. x = self.conv2(self.norm2(x))
  663. x = self.conv3(self.norm3(x))
  664. x = self.drop_path(x)
  665. return x + shortcut
  666. class Bottleneck(nn.Layer):
  667. """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT."""
  668. def __init__(
  669. self,
  670. in_chs,
  671. out_chs=None,
  672. bottle_ratio=0.25,
  673. stride=1,
  674. dilation=1,
  675. first_dilation=None,
  676. groups=1,
  677. act_layer=None,
  678. conv_layer=None,
  679. norm_layer=None,
  680. proj_layer=None,
  681. drop_path_rate=0.0,
  682. is_export=False,
  683. ):
  684. super().__init__()
  685. first_dilation = first_dilation or dilation
  686. act_layer = act_layer or nn.ReLU
  687. conv_layer = conv_layer or StdConv2d
  688. norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
  689. out_chs = out_chs or in_chs
  690. mid_chs = make_div(out_chs * bottle_ratio)
  691. if proj_layer is not None:
  692. self.downsample = proj_layer(
  693. in_chs,
  694. out_chs,
  695. stride=stride,
  696. dilation=dilation,
  697. preact=False,
  698. conv_layer=conv_layer,
  699. norm_layer=norm_layer,
  700. is_export=is_export,
  701. )
  702. else:
  703. self.downsample = None
  704. self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
  705. self.norm1 = norm_layer(mid_chs)
  706. self.conv2 = conv_layer(
  707. mid_chs,
  708. mid_chs,
  709. 3,
  710. stride=stride,
  711. dilation=first_dilation,
  712. groups=groups,
  713. is_export=is_export,
  714. )
  715. self.norm2 = norm_layer(mid_chs)
  716. self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
  717. self.norm3 = norm_layer(out_chs, apply_act=False)
  718. self.drop_path = (
  719. DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  720. )
  721. self.act3 = act_layer()
  722. def zero_init_last(self):
  723. zeros_(self.norm3.weight)
  724. def forward(self, x):
  725. # shortcut branch
  726. shortcut = x
  727. if self.downsample is not None:
  728. shortcut = self.downsample(x)
  729. # residual
  730. x = self.conv1(x)
  731. x = self.norm1(x)
  732. x = self.conv2(x)
  733. x = self.norm2(x)
  734. x = self.conv3(x)
  735. x = self.norm3(x)
  736. x = self.drop_path(x)
  737. x = self.act3(x + shortcut)
  738. return x
  739. class DownsampleConv(nn.Layer):
  740. def __init__(
  741. self,
  742. in_chs,
  743. out_chs,
  744. stride=1,
  745. dilation=1,
  746. first_dilation=None,
  747. preact=True,
  748. conv_layer=None,
  749. norm_layer=None,
  750. is_export=False,
  751. ):
  752. super(DownsampleConv, self).__init__()
  753. self.conv = conv_layer(in_chs, out_chs, 1, stride=stride, is_export=is_export)
  754. self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
  755. def forward(self, x):
  756. return self.norm(self.conv(x))
  757. class DownsampleAvg(nn.Layer):
  758. def __init__(
  759. self,
  760. in_chs,
  761. out_chs,
  762. stride=1,
  763. dilation=1,
  764. first_dilation=None,
  765. preact=True,
  766. conv_layer=None,
  767. norm_layer=None,
  768. is_export=False,
  769. ):
  770. """AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
  771. super(DownsampleAvg, self).__init__()
  772. avg_stride = stride if dilation == 1 else 1
  773. if stride > 1 or dilation > 1:
  774. avg_pool_fn = (
  775. AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2D
  776. )
  777. self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, exclusive=False)
  778. else:
  779. self.pool = nn.Identity()
  780. self.conv = conv_layer(in_chs, out_chs, 1, stride=1, is_export=is_export)
  781. self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
  782. def forward(self, x):
  783. return self.norm(self.conv(self.pool(x)))
  784. class ResNetStage(nn.Layer):
  785. """ResNet Stage."""
  786. def __init__(
  787. self,
  788. in_chs,
  789. out_chs,
  790. stride,
  791. dilation,
  792. depth,
  793. bottle_ratio=0.25,
  794. groups=1,
  795. avg_down=False,
  796. block_dpr=None,
  797. block_fn=PreActBottleneck,
  798. is_export=False,
  799. act_layer=None,
  800. conv_layer=None,
  801. norm_layer=None,
  802. **block_kwargs,
  803. ):
  804. super(ResNetStage, self).__init__()
  805. first_dilation = 1 if dilation in (1, 2) else 2
  806. layer_kwargs = dict(
  807. act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer
  808. )
  809. proj_layer = DownsampleAvg if avg_down else DownsampleConv
  810. prev_chs = in_chs
  811. self.blocks = nn.Sequential()
  812. for block_idx in range(depth):
  813. drop_path_rate = block_dpr[block_idx] if block_dpr else 0.0
  814. stride = stride if block_idx == 0 else 1
  815. self.blocks.add_sublayer(
  816. str(block_idx),
  817. block_fn(
  818. prev_chs,
  819. out_chs,
  820. stride=stride,
  821. dilation=dilation,
  822. bottle_ratio=bottle_ratio,
  823. groups=groups,
  824. first_dilation=first_dilation,
  825. proj_layer=proj_layer,
  826. drop_path_rate=drop_path_rate,
  827. is_export=is_export,
  828. **layer_kwargs,
  829. **block_kwargs,
  830. ),
  831. )
  832. prev_chs = out_chs
  833. first_dilation = dilation
  834. proj_layer = None
  835. def forward(self, x):
  836. x = self.blocks(x)
  837. return x
  838. def is_stem_deep(stem_type):
  839. return any([s in stem_type for s in ("deep", "tiered")])
  840. def create_resnetv2_stem(
  841. in_chs,
  842. out_chs=64,
  843. stem_type="",
  844. preact=True,
  845. conv_layer=StdConv2d,
  846. norm_layer=partial(GroupNormAct, num_groups=32),
  847. is_export=False,
  848. ):
  849. stem = OrderedDict()
  850. assert stem_type in (
  851. "",
  852. "fixed",
  853. "same",
  854. "deep",
  855. "deep_fixed",
  856. "deep_same",
  857. "tiered",
  858. )
  859. # NOTE conv padding mode can be changed by overriding the conv_layer def
  860. if is_stem_deep(stem_type):
  861. # A 3 deep 3x3 conv stack as in ResNet V1D models
  862. if "tiered" in stem_type:
  863. stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
  864. else:
  865. stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets
  866. stem["conv1"] = conv_layer(
  867. in_chs, stem_chs[0], kernel_size=3, stride=2, is_export=is_export
  868. )
  869. stem["norm1"] = norm_layer(stem_chs[0])
  870. stem["conv2"] = conv_layer(
  871. stem_chs[0], stem_chs[1], kernel_size=3, stride=1, is_export=is_export
  872. )
  873. stem["norm2"] = norm_layer(stem_chs[1])
  874. stem["conv3"] = conv_layer(
  875. stem_chs[1], out_chs, kernel_size=3, stride=1, is_export=is_export
  876. )
  877. if not preact:
  878. stem["norm3"] = norm_layer(out_chs)
  879. else:
  880. # The usual 7x7 stem conv
  881. stem["conv"] = conv_layer(
  882. in_chs, out_chs, kernel_size=7, stride=2, is_export=is_export
  883. )
  884. if not preact:
  885. stem["norm"] = norm_layer(out_chs)
  886. if "fixed" in stem_type:
  887. # 'fixed' SAME padding approximation that is used in BiT models
  888. stem["pad"] = paddle.nn.Pad2D(
  889. 1, mode="constant", value=0.0, data_format="NCHW", name=None
  890. )
  891. stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=0)
  892. elif "same" in stem_type:
  893. # full, input size based 'SAME' padding, used in ViT Hybrid model
  894. stem["pool"] = create_pool2d(
  895. "max", kernel_size=3, stride=2, padding="same", is_export=is_export
  896. )
  897. else:
  898. # the usual Pypaddle symmetric padding
  899. stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
  900. stem_seq = nn.Sequential()
  901. for key, value in stem.items():
  902. stem_seq.add_sublayer(key, value)
  903. return stem_seq
  904. class ResNetV2(nn.Layer):
  905. """Implementation of Pre-activation (v2) ResNet mode.
  906. Args:
  907. x: input images with shape [N, 1, H, W]
  908. Returns:
  909. The extracted features [N, 1, H//16, W//16]
  910. """
  911. def __init__(
  912. self,
  913. layers,
  914. channels=(256, 512, 1024, 2048),
  915. num_classes=1000,
  916. in_chans=3,
  917. global_pool="avg",
  918. output_stride=32,
  919. width_factor=1,
  920. stem_chs=64,
  921. stem_type="",
  922. avg_down=False,
  923. preact=True,
  924. act_layer=nn.ReLU,
  925. conv_layer=StdConv2d,
  926. norm_layer=partial(GroupNormAct, num_groups=32),
  927. drop_rate=0.0,
  928. drop_path_rate=0.0,
  929. zero_init_last=False,
  930. is_export=False,
  931. ):
  932. super().__init__()
  933. self.num_classes = num_classes
  934. self.drop_rate = drop_rate
  935. self.is_export = is_export
  936. wf = width_factor
  937. self.feature_info = []
  938. stem_chs = make_div(stem_chs * wf)
  939. self.stem = create_resnetv2_stem(
  940. in_chans,
  941. stem_chs,
  942. stem_type,
  943. preact,
  944. conv_layer=conv_layer,
  945. norm_layer=norm_layer,
  946. is_export=is_export,
  947. )
  948. stem_feat = (
  949. ("stem.conv3" if is_stem_deep(stem_type) else "stem.conv")
  950. if preact
  951. else "stem.norm"
  952. )
  953. self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
  954. prev_chs = stem_chs
  955. curr_stride = 4
  956. dilation = 1
  957. block_dprs = [
  958. x.tolist()
  959. for x in paddle.linspace(0, drop_path_rate, sum(layers)).split(layers)
  960. ]
  961. block_fn = PreActBottleneck if preact else Bottleneck
  962. self.stages = nn.Sequential()
  963. for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
  964. out_chs = make_div(c * wf)
  965. stride = 1 if stage_idx == 0 else 2
  966. if curr_stride >= output_stride:
  967. dilation *= stride
  968. stride = 1
  969. stage = ResNetStage(
  970. prev_chs,
  971. out_chs,
  972. stride=stride,
  973. dilation=dilation,
  974. depth=d,
  975. avg_down=avg_down,
  976. act_layer=act_layer,
  977. conv_layer=conv_layer,
  978. norm_layer=norm_layer,
  979. block_dpr=bdpr,
  980. block_fn=block_fn,
  981. is_export=is_export,
  982. )
  983. prev_chs = out_chs
  984. curr_stride *= stride
  985. self.feature_info += [
  986. dict(
  987. num_chs=prev_chs,
  988. reduction=curr_stride,
  989. module=f"stages.{stage_idx}",
  990. )
  991. ]
  992. self.stages.add_sublayer(str(stage_idx), stage)
  993. self.num_features = prev_chs
  994. self.norm = norm_layer(self.num_features) if preact else nn.Identity()
  995. self.head = ClassifierHead(
  996. self.num_features,
  997. num_classes,
  998. pool_type=global_pool,
  999. drop_rate=self.drop_rate,
  1000. use_conv=True,
  1001. )
  1002. self.init_weights(zero_init_last=zero_init_last)
  1003. def init_weights(self, zero_init_last=True):
  1004. named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
  1005. def load_pretrained(self, checkpoint_path, prefix="resnet/"):
  1006. _load_weights(self, checkpoint_path, prefix)
  1007. def get_classifier(self):
  1008. return self.head.fc
  1009. def reset_classifier(self, num_classes, global_pool="avg"):
  1010. self.num_classes = num_classes
  1011. self.head = ClassifierHead(
  1012. self.num_features,
  1013. num_classes,
  1014. pool_type=global_pool,
  1015. drop_rate=self.drop_rate,
  1016. use_conv=True,
  1017. )
  1018. def forward_features(self, x):
  1019. x = self.stem(x)
  1020. x = self.stages(x)
  1021. x = self.norm(x)
  1022. return x
  1023. def forward(self, x):
  1024. x = self.forward_features(x)
  1025. x = self.head(x)
  1026. return x
  1027. def _init_weights(module: nn.Layer, name: str = "", zero_init_last=True):
  1028. if isinstance(module, nn.Linear) or (
  1029. "head.fc" in name and isinstance(module, nn.Conv2D)
  1030. ):
  1031. normal_(module.weight)
  1032. zeros_(module.bias)
  1033. elif isinstance(module, nn.Conv2D):
  1034. kaiming_normal_(module.weight)
  1035. if module.bias is not None:
  1036. zeros_(module.bias)
  1037. elif isinstance(module, (nn.BatchNorm2D, nn.LayerNorm, nn.GroupNorm)):
  1038. ones_(module.weight)
  1039. zeros_(module.bias)
  1040. elif zero_init_last and hasattr(module, "zero_init_last"):
  1041. module.zero_init_last()
  1042. @paddle.no_grad()
  1043. def _load_weights(model: nn.Layer, checkpoint_path: str, prefix: str = "resnet/"):
  1044. import numpy as np
  1045. def t2p(conv_weights):
  1046. """Possibly convert HWIO to OIHW."""
  1047. if conv_weights.ndim == 4:
  1048. conv_weights = conv_weights.transpose([3, 2, 0, 1])
  1049. return paddle.to_tensor(conv_weights)
  1050. weights = np.load(checkpoint_path)
  1051. stem_conv_w = adapt_input_conv(
  1052. model.stem.conv.weight.shape[1],
  1053. t2p(weights[f"{prefix}root_block/standardized_conv2d/kernel"]),
  1054. )
  1055. model.stem.conv.weight.copy_(stem_conv_w)
  1056. model.norm.weight.copy_(t2p(weights[f"{prefix}group_norm/gamma"]))
  1057. model.norm.bias.copy_(t2p(weights[f"{prefix}group_norm/beta"]))
  1058. if (
  1059. isinstance(getattr(model.head, "fc", None), nn.Conv2D)
  1060. and model.head.fc.weight.shape[0]
  1061. == weights[f"{prefix}head/conv2d/kernel"].shape[-1]
  1062. ):
  1063. model.head.fc.weight.copy_(t2p(weights[f"{prefix}head/conv2d/kernel"]))
  1064. model.head.fc.bias.copy_(t2p(weights[f"{prefix}head/conv2d/bias"]))
  1065. for i, (sname, stage) in enumerate(model.stages.named_children()):
  1066. for j, (bname, block) in enumerate(stage.blocks.named_children()):
  1067. cname = "standardized_conv2d"
  1068. block_prefix = f"{prefix}block{i + 1}/unit{j + 1:02d}/"
  1069. block.conv1.weight.copy_(t2p(weights[f"{block_prefix}a/{cname}/kernel"]))
  1070. block.conv2.weight.copy_(t2p(weights[f"{block_prefix}b/{cname}/kernel"]))
  1071. block.conv3.weight.copy_(t2p(weights[f"{block_prefix}c/{cname}/kernel"]))
  1072. block.norm1.weight.copy_(t2p(weights[f"{block_prefix}a/group_norm/gamma"]))
  1073. block.norm2.weight.copy_(t2p(weights[f"{block_prefix}b/group_norm/gamma"]))
  1074. block.norm3.weight.copy_(t2p(weights[f"{block_prefix}c/group_norm/gamma"]))
  1075. block.norm1.bias.copy_(t2p(weights[f"{block_prefix}a/group_norm/beta"]))
  1076. block.norm2.bias.copy_(t2p(weights[f"{block_prefix}b/group_norm/beta"]))
  1077. block.norm3.bias.copy_(t2p(weights[f"{block_prefix}c/group_norm/beta"]))
  1078. if block.downsample is not None:
  1079. w = weights[f"{block_prefix}a/proj/{cname}/kernel"]
  1080. block.downsample.conv.weight.copy_(t2p(w))