rec_repvit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/THU-MIG/RepViT
"""
  18. import paddle.nn as nn
  19. import paddle
  20. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  21. trunc_normal_ = TruncatedNormal(std=0.02)
  22. normal_ = Normal
  23. zeros_ = Constant(value=0.0)
  24. ones_ = Constant(value=1.0)
  25. def _make_divisible(v, divisor, min_value=None):
  26. """
  27. This function is taken from the original tf repo.
  28. It ensures that all layers have a channel number that is divisible by 8
  29. It can be seen here:
  30. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  31. :param v:
  32. :param divisor:
  33. :param min_value:
  34. :return:
  35. """
  36. if min_value is None:
  37. min_value = divisor
  38. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  39. # Make sure that round down does not go down by more than 10%.
  40. if new_v < 0.9 * v:
  41. new_v += divisor
  42. return new_v
  43. # from timm.models.layers import SqueezeExcite
  44. def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
  45. min_value = min_value or divisor
  46. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  47. # Make sure that round down does not go down by more than 10%.
  48. if new_v < round_limit * v:
  49. new_v += divisor
  50. return new_v
  51. class SEModule(nn.Layer):
  52. """SE Module as defined in original SE-Nets with a few additions
  53. Additions include:
  54. * divisor can be specified to keep channels % div == 0 (default: 8)
  55. * reduction channels can be specified directly by arg (if rd_channels is set)
  56. * reduction channels can be specified by float rd_ratio (default: 1/16)
  57. * global max pooling can be added to the squeeze aggregation
  58. * customizable activation, normalization, and gate layer
  59. """
  60. def __init__(
  61. self,
  62. channels,
  63. rd_ratio=1.0 / 16,
  64. rd_channels=None,
  65. rd_divisor=8,
  66. act_layer=nn.ReLU,
  67. ):
  68. super(SEModule, self).__init__()
  69. if not rd_channels:
  70. rd_channels = make_divisible(
  71. channels * rd_ratio, rd_divisor, round_limit=0.0
  72. )
  73. self.fc1 = nn.Conv2D(channels, rd_channels, kernel_size=1, bias_attr=True)
  74. self.act = act_layer()
  75. self.fc2 = nn.Conv2D(rd_channels, channels, kernel_size=1, bias_attr=True)
  76. def forward(self, x):
  77. x_se = x.mean((2, 3), keepdim=True)
  78. x_se = self.fc1(x_se)
  79. x_se = self.act(x_se)
  80. x_se = self.fc2(x_se)
  81. return x * nn.functional.sigmoid(x_se)
  82. class Conv2D_BN(nn.Sequential):
  83. def __init__(
  84. self,
  85. a,
  86. b,
  87. ks=1,
  88. stride=1,
  89. pad=0,
  90. dilation=1,
  91. groups=1,
  92. bn_weight_init=1,
  93. resolution=-10000,
  94. ):
  95. super().__init__()
  96. self.add_sublayer(
  97. "c", nn.Conv2D(a, b, ks, stride, pad, dilation, groups, bias_attr=False)
  98. )
  99. self.add_sublayer("bn", nn.BatchNorm2D(b))
  100. if bn_weight_init == 1:
  101. ones_(self.bn.weight)
  102. else:
  103. zeros_(self.bn.weight)
  104. zeros_(self.bn.bias)
  105. @paddle.no_grad()
  106. def fuse(self):
  107. c, bn = self.c, self.bn
  108. w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
  109. w = c.weight * w[:, None, None, None]
  110. b = bn.bias - bn._mean * bn.weight / (bn._variance + bn._epsilon) ** 0.5
  111. m = nn.Conv2D(
  112. w.shape[1] * self.c._groups,
  113. w.shape[0],
  114. w.shape[2:],
  115. stride=self.c._stride,
  116. padding=self.c._padding,
  117. dilation=self.c._dilation,
  118. groups=self.c._groups,
  119. )
  120. m.weight.set_value(w)
  121. m.bias.set_value(b)
  122. return m
  123. class Residual(nn.Layer):
  124. def __init__(self, m, drop=0.0):
  125. super().__init__()
  126. self.m = m
  127. self.drop = drop
  128. def forward(self, x):
  129. if self.training and self.drop > 0:
  130. return (
  131. x
  132. + self.m(x)
  133. * paddle.rand(x.size(0), 1, 1, 1)
  134. .ge_(self.drop)
  135. .div(1 - self.drop)
  136. .detach()
  137. )
  138. else:
  139. return x + self.m(x)
  140. @paddle.no_grad()
  141. def fuse(self):
  142. if isinstance(self.m, Conv2D_BN):
  143. m = self.m.fuse()
  144. assert m._groups == m.in_channels
  145. identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
  146. identity = nn.functional.pad(identity, [1, 1, 1, 1])
  147. m.weight += identity
  148. return m
  149. elif isinstance(self.m, nn.Conv2D):
  150. m = self.m
  151. assert m._groups != m.in_channels
  152. identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
  153. identity = nn.functional.pad(identity, [1, 1, 1, 1])
  154. m.weight += identity
  155. return m
  156. else:
  157. return self
  158. class RepVGGDW(nn.Layer):
  159. def __init__(self, ed) -> None:
  160. super().__init__()
  161. self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
  162. self.conv1 = nn.Conv2D(ed, ed, 1, 1, 0, groups=ed)
  163. self.dim = ed
  164. self.bn = nn.BatchNorm2D(ed)
  165. def forward(self, x):
  166. return self.bn((self.conv(x) + self.conv1(x)) + x)
  167. @paddle.no_grad()
  168. def fuse(self):
  169. conv = self.conv.fuse()
  170. conv1 = self.conv1
  171. conv_w = conv.weight
  172. conv_b = conv.bias
  173. conv1_w = conv1.weight
  174. conv1_b = conv1.bias
  175. conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
  176. identity = nn.functional.pad(
  177. paddle.ones([conv1_w.shape[0], conv1_w.shape[1], 1, 1]), [1, 1, 1, 1]
  178. )
  179. final_conv_w = conv_w + conv1_w + identity
  180. final_conv_b = conv_b + conv1_b
  181. conv.weight.set_value(final_conv_w)
  182. conv.bias.set_value(final_conv_b)
  183. bn = self.bn
  184. w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
  185. w = conv.weight * w[:, None, None, None]
  186. b = (
  187. bn.bias
  188. + (conv.bias - bn._mean) * bn.weight / (bn._variance + bn._epsilon) ** 0.5
  189. )
  190. conv.weight.set_value(w)
  191. conv.bias.set_value(b)
  192. return conv
  193. class RepViTBlock(nn.Layer):
  194. def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
  195. super(RepViTBlock, self).__init__()
  196. self.identity = stride == 1 and inp == oup
  197. assert hidden_dim == 2 * inp
  198. if stride != 1:
  199. self.token_mixer = nn.Sequential(
  200. Conv2D_BN(
  201. inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp
  202. ),
  203. SEModule(inp, 0.25) if use_se else nn.Identity(),
  204. Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
  205. )
  206. self.channel_mixer = Residual(
  207. nn.Sequential(
  208. # pw
  209. Conv2D_BN(oup, 2 * oup, 1, 1, 0),
  210. nn.GELU() if use_hs else nn.GELU(),
  211. # pw-linear
  212. Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
  213. )
  214. )
  215. else:
  216. assert self.identity
  217. self.token_mixer = nn.Sequential(
  218. RepVGGDW(inp),
  219. SEModule(inp, 0.25) if use_se else nn.Identity(),
  220. )
  221. self.channel_mixer = Residual(
  222. nn.Sequential(
  223. # pw
  224. Conv2D_BN(inp, hidden_dim, 1, 1, 0),
  225. nn.GELU() if use_hs else nn.GELU(),
  226. # pw-linear
  227. Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
  228. )
  229. )
  230. def forward(self, x):
  231. return self.channel_mixer(self.token_mixer(x))
  232. class RepViT(nn.Layer):
  233. def __init__(self, cfgs, in_channels=3, out_indices=None):
  234. super(RepViT, self).__init__()
  235. # setting of inverted residual blocks
  236. self.cfgs = cfgs
  237. # building first layer
  238. input_channel = self.cfgs[0][2]
  239. patch_embed = nn.Sequential(
  240. Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
  241. nn.GELU(),
  242. Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
  243. )
  244. layers = [patch_embed]
  245. # building inverted residual blocks
  246. block = RepViTBlock
  247. for k, t, c, use_se, use_hs, s in self.cfgs:
  248. output_channel = _make_divisible(c, 8)
  249. exp_size = _make_divisible(input_channel * t, 8)
  250. layers.append(
  251. block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)
  252. )
  253. input_channel = output_channel
  254. self.features = nn.LayerList(layers)
  255. self.out_indices = out_indices
  256. if out_indices is not None:
  257. self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
  258. else:
  259. self.out_channels = self.cfgs[-1][2]
  260. def forward(self, x):
  261. if self.out_indices is not None:
  262. return self.forward_det(x)
  263. return self.forward_rec(x)
  264. def forward_det(self, x):
  265. outs = []
  266. for i, f in enumerate(self.features):
  267. x = f(x)
  268. if i in self.out_indices:
  269. outs.append(x)
  270. return outs
  271. def forward_rec(self, x):
  272. for f in self.features:
  273. x = f(x)
  274. h = x.shape[2]
  275. x = nn.functional.avg_pool2d(x, [h, 2])
  276. return x
  277. def RepSVTR(in_channels=3):
  278. """
  279. Constructs a MobileNetV3-Large model
  280. """
  281. # k, t, c, SE, HS, s
  282. cfgs = [
  283. [3, 2, 96, 1, 0, 1],
  284. [3, 2, 96, 0, 0, 1],
  285. [3, 2, 96, 0, 0, 1],
  286. [3, 2, 192, 0, 1, (2, 1)],
  287. [3, 2, 192, 1, 1, 1],
  288. [3, 2, 192, 0, 1, 1],
  289. [3, 2, 192, 1, 1, 1],
  290. [3, 2, 192, 0, 1, 1],
  291. [3, 2, 192, 1, 1, 1],
  292. [3, 2, 192, 0, 1, 1],
  293. [3, 2, 384, 0, 1, (2, 1)],
  294. [3, 2, 384, 1, 1, 1],
  295. [3, 2, 384, 0, 1, 1],
  296. ]
  297. return RepViT(cfgs, in_channels=in_channels)
  298. def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
  299. """
  300. Constructs a MobileNetV3-Large model
  301. """
  302. # k, t, c, SE, HS, s
  303. cfgs = [
  304. [3, 2, 48, 1, 0, 1],
  305. [3, 2, 48, 0, 0, 1],
  306. [3, 2, 96, 0, 0, 2],
  307. [3, 2, 96, 1, 0, 1],
  308. [3, 2, 96, 0, 0, 1],
  309. [3, 2, 192, 0, 1, 2],
  310. [3, 2, 192, 1, 1, 1],
  311. [3, 2, 192, 0, 1, 1],
  312. [3, 2, 192, 1, 1, 1],
  313. [3, 2, 192, 0, 1, 1],
  314. [3, 2, 384, 0, 1, 2],
  315. [3, 2, 384, 1, 1, 1],
  316. [3, 2, 384, 0, 1, 1],
  317. ]
  318. return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)