rec_svtrnet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle import ParamAttr
  15. from paddle.nn.initializer import KaimingNormal
  16. import numpy as np
  17. import paddle
  18. import paddle.nn as nn
  19. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  20. trunc_normal_ = TruncatedNormal(std=0.02)
  21. normal_ = Normal
  22. zeros_ = Constant(value=0.0)
  23. ones_ = Constant(value=1.0)
  24. def drop_path(x, drop_prob=0.0, training=False):
  25. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  26. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  27. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  28. """
  29. if drop_prob == 0.0 or not training:
  30. return x
  31. keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
  32. shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  33. random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
  34. random_tensor = paddle.floor(random_tensor) # binarize
  35. output = x.divide(keep_prob) * random_tensor
  36. return output
  37. class ConvBNLayer(nn.Layer):
  38. def __init__(
  39. self,
  40. in_channels,
  41. out_channels,
  42. kernel_size=3,
  43. stride=1,
  44. padding=0,
  45. bias_attr=False,
  46. groups=1,
  47. act=nn.GELU,
  48. ):
  49. super().__init__()
  50. self.conv = nn.Conv2D(
  51. in_channels=in_channels,
  52. out_channels=out_channels,
  53. kernel_size=kernel_size,
  54. stride=stride,
  55. padding=padding,
  56. groups=groups,
  57. weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
  58. bias_attr=bias_attr,
  59. )
  60. self.norm = nn.BatchNorm2D(out_channels)
  61. self.act = act()
  62. def forward(self, inputs):
  63. out = self.conv(inputs)
  64. out = self.norm(out)
  65. out = self.act(out)
  66. return out
  67. class DropPath(nn.Layer):
  68. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  69. def __init__(self, drop_prob=None):
  70. super(DropPath, self).__init__()
  71. self.drop_prob = drop_prob
  72. def forward(self, x):
  73. return drop_path(x, self.drop_prob, self.training)
  74. class Identity(nn.Layer):
  75. def __init__(self):
  76. super(Identity, self).__init__()
  77. def forward(self, input):
  78. return input
  79. class Mlp(nn.Layer):
  80. def __init__(
  81. self,
  82. in_features,
  83. hidden_features=None,
  84. out_features=None,
  85. act_layer=nn.GELU,
  86. drop=0.0,
  87. ):
  88. super().__init__()
  89. out_features = out_features or in_features
  90. hidden_features = hidden_features or in_features
  91. self.fc1 = nn.Linear(in_features, hidden_features)
  92. self.act = act_layer()
  93. self.fc2 = nn.Linear(hidden_features, out_features)
  94. self.drop = nn.Dropout(drop)
  95. def forward(self, x):
  96. x = self.fc1(x)
  97. x = self.act(x)
  98. x = self.drop(x)
  99. x = self.fc2(x)
  100. x = self.drop(x)
  101. return x
  102. class ConvMixer(nn.Layer):
  103. def __init__(
  104. self,
  105. dim,
  106. num_heads=8,
  107. HW=[8, 25],
  108. local_k=[3, 3],
  109. ):
  110. super().__init__()
  111. self.HW = HW
  112. self.dim = dim
  113. self.local_mixer = nn.Conv2D(
  114. dim,
  115. dim,
  116. local_k,
  117. 1,
  118. [local_k[0] // 2, local_k[1] // 2],
  119. groups=num_heads,
  120. weight_attr=ParamAttr(initializer=KaimingNormal()),
  121. )
  122. def forward(self, x):
  123. h = self.HW[0]
  124. w = self.HW[1]
  125. x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
  126. x = self.local_mixer(x)
  127. x = x.flatten(2).transpose([0, 2, 1])
  128. return x
  129. class Attention(nn.Layer):
  130. def __init__(
  131. self,
  132. dim,
  133. num_heads=8,
  134. mixer="Global",
  135. HW=None,
  136. local_k=[7, 11],
  137. qkv_bias=False,
  138. qk_scale=None,
  139. attn_drop=0.0,
  140. proj_drop=0.0,
  141. ):
  142. super().__init__()
  143. self.num_heads = num_heads
  144. self.dim = dim
  145. self.head_dim = dim // num_heads
  146. self.scale = qk_scale or self.head_dim**-0.5
  147. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  148. self.attn_drop = nn.Dropout(attn_drop)
  149. self.proj = nn.Linear(dim, dim)
  150. self.proj_drop = nn.Dropout(proj_drop)
  151. self.HW = HW
  152. if HW is not None:
  153. H = HW[0]
  154. W = HW[1]
  155. self.N = H * W
  156. self.C = dim
  157. if mixer == "Local" and HW is not None:
  158. hk = local_k[0]
  159. wk = local_k[1]
  160. mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype="float32")
  161. for h in range(0, H):
  162. for w in range(0, W):
  163. mask[h * W + w, h : h + hk, w : w + wk] = 0.0
  164. mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
  165. 1
  166. )
  167. mask_inf = paddle.full([H * W, H * W], "-inf", dtype="float32")
  168. mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
  169. self.mask = mask.unsqueeze([0, 1])
  170. self.mixer = mixer
  171. def forward(self, x):
  172. qkv = (
  173. self.qkv(x)
  174. .reshape((0, -1, 3, self.num_heads, self.head_dim))
  175. .transpose((2, 0, 3, 1, 4))
  176. )
  177. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  178. attn = q.matmul(k.transpose((0, 1, 3, 2)))
  179. if self.mixer == "Local":
  180. attn += self.mask
  181. attn = nn.functional.softmax(attn, axis=-1)
  182. attn = self.attn_drop(attn)
  183. x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
  184. x = self.proj(x)
  185. x = self.proj_drop(x)
  186. return x
  187. class Block(nn.Layer):
  188. def __init__(
  189. self,
  190. dim,
  191. num_heads,
  192. mixer="Global",
  193. local_mixer=[7, 11],
  194. HW=None,
  195. mlp_ratio=4.0,
  196. qkv_bias=False,
  197. qk_scale=None,
  198. drop=0.0,
  199. attn_drop=0.0,
  200. drop_path=0.0,
  201. act_layer=nn.GELU,
  202. norm_layer="nn.LayerNorm",
  203. epsilon=1e-6,
  204. prenorm=True,
  205. ):
  206. super().__init__()
  207. if isinstance(norm_layer, str):
  208. self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
  209. else:
  210. self.norm1 = norm_layer(dim)
  211. if mixer == "Global" or mixer == "Local":
  212. self.mixer = Attention(
  213. dim,
  214. num_heads=num_heads,
  215. mixer=mixer,
  216. HW=HW,
  217. local_k=local_mixer,
  218. qkv_bias=qkv_bias,
  219. qk_scale=qk_scale,
  220. attn_drop=attn_drop,
  221. proj_drop=drop,
  222. )
  223. elif mixer == "Conv":
  224. self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
  225. else:
  226. raise TypeError("The mixer must be one of [Global, Local, Conv]")
  227. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  228. if isinstance(norm_layer, str):
  229. self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
  230. else:
  231. self.norm2 = norm_layer(dim)
  232. mlp_hidden_dim = int(dim * mlp_ratio)
  233. self.mlp_ratio = mlp_ratio
  234. self.mlp = Mlp(
  235. in_features=dim,
  236. hidden_features=mlp_hidden_dim,
  237. act_layer=act_layer,
  238. drop=drop,
  239. )
  240. self.prenorm = prenorm
  241. def forward(self, x):
  242. if self.prenorm:
  243. x = self.norm1(x + self.drop_path(self.mixer(x)))
  244. x = self.norm2(x + self.drop_path(self.mlp(x)))
  245. else:
  246. x = x + self.drop_path(self.mixer(self.norm1(x)))
  247. x = x + self.drop_path(self.mlp(self.norm2(x)))
  248. return x
  249. class PatchEmbed(nn.Layer):
  250. """Image to Patch Embedding"""
  251. def __init__(
  252. self,
  253. img_size=[32, 100],
  254. in_channels=3,
  255. embed_dim=768,
  256. sub_num=2,
  257. patch_size=[4, 4],
  258. mode="pope",
  259. ):
  260. super().__init__()
  261. num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
  262. self.img_size = img_size
  263. self.num_patches = num_patches
  264. self.embed_dim = embed_dim
  265. self.norm = None
  266. if mode == "pope":
  267. if sub_num == 2:
  268. self.proj = nn.Sequential(
  269. ConvBNLayer(
  270. in_channels=in_channels,
  271. out_channels=embed_dim // 2,
  272. kernel_size=3,
  273. stride=2,
  274. padding=1,
  275. act=nn.GELU,
  276. bias_attr=None,
  277. ),
  278. ConvBNLayer(
  279. in_channels=embed_dim // 2,
  280. out_channels=embed_dim,
  281. kernel_size=3,
  282. stride=2,
  283. padding=1,
  284. act=nn.GELU,
  285. bias_attr=None,
  286. ),
  287. )
  288. if sub_num == 3:
  289. self.proj = nn.Sequential(
  290. ConvBNLayer(
  291. in_channels=in_channels,
  292. out_channels=embed_dim // 4,
  293. kernel_size=3,
  294. stride=2,
  295. padding=1,
  296. act=nn.GELU,
  297. bias_attr=None,
  298. ),
  299. ConvBNLayer(
  300. in_channels=embed_dim // 4,
  301. out_channels=embed_dim // 2,
  302. kernel_size=3,
  303. stride=2,
  304. padding=1,
  305. act=nn.GELU,
  306. bias_attr=None,
  307. ),
  308. ConvBNLayer(
  309. in_channels=embed_dim // 2,
  310. out_channels=embed_dim,
  311. kernel_size=3,
  312. stride=2,
  313. padding=1,
  314. act=nn.GELU,
  315. bias_attr=None,
  316. ),
  317. )
  318. elif mode == "linear":
  319. self.proj = nn.Conv2D(
  320. 1, embed_dim, kernel_size=patch_size, stride=patch_size
  321. )
  322. self.num_patches = (
  323. img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
  324. )
  325. def forward(self, x):
  326. B, C, H, W = x.shape
  327. assert (
  328. H == self.img_size[0] and W == self.img_size[1]
  329. ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  330. x = self.proj(x).flatten(2).transpose((0, 2, 1))
  331. return x
  332. class SubSample(nn.Layer):
  333. def __init__(
  334. self,
  335. in_channels,
  336. out_channels,
  337. types="Pool",
  338. stride=[2, 1],
  339. sub_norm="nn.LayerNorm",
  340. act=None,
  341. ):
  342. super().__init__()
  343. self.types = types
  344. if types == "Pool":
  345. self.avgpool = nn.AvgPool2D(
  346. kernel_size=[3, 5], stride=stride, padding=[1, 2]
  347. )
  348. self.maxpool = nn.MaxPool2D(
  349. kernel_size=[3, 5], stride=stride, padding=[1, 2]
  350. )
  351. self.proj = nn.Linear(in_channels, out_channels)
  352. else:
  353. self.conv = nn.Conv2D(
  354. in_channels,
  355. out_channels,
  356. kernel_size=3,
  357. stride=stride,
  358. padding=1,
  359. weight_attr=ParamAttr(initializer=KaimingNormal()),
  360. )
  361. self.norm = eval(sub_norm)(out_channels)
  362. if act is not None:
  363. self.act = act()
  364. else:
  365. self.act = None
  366. def forward(self, x):
  367. if self.types == "Pool":
  368. x1 = self.avgpool(x)
  369. x2 = self.maxpool(x)
  370. x = (x1 + x2) * 0.5
  371. out = self.proj(x.flatten(2).transpose((0, 2, 1)))
  372. else:
  373. x = self.conv(x)
  374. out = x.flatten(2).transpose((0, 2, 1))
  375. out = self.norm(out)
  376. if self.act is not None:
  377. out = self.act(out)
  378. return out
  379. class SVTRNet(nn.Layer):
  380. def __init__(
  381. self,
  382. img_size=[32, 100],
  383. in_channels=3,
  384. embed_dim=[64, 128, 256],
  385. depth=[3, 6, 3],
  386. num_heads=[2, 4, 8],
  387. mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
  388. local_mixer=[[7, 11], [7, 11], [7, 11]],
  389. patch_merging="Conv", # Conv, Pool, None
  390. mlp_ratio=4,
  391. qkv_bias=True,
  392. qk_scale=None,
  393. drop_rate=0.0,
  394. last_drop=0.1,
  395. attn_drop_rate=0.0,
  396. drop_path_rate=0.1,
  397. norm_layer="nn.LayerNorm",
  398. sub_norm="nn.LayerNorm",
  399. epsilon=1e-6,
  400. out_channels=192,
  401. out_char_num=25,
  402. block_unit="Block",
  403. act="nn.GELU",
  404. last_stage=True,
  405. sub_num=2,
  406. prenorm=True,
  407. use_lenhead=False,
  408. **kwargs,
  409. ):
  410. super().__init__()
  411. self.img_size = img_size
  412. self.embed_dim = embed_dim
  413. self.out_channels = out_channels
  414. self.prenorm = prenorm
  415. patch_merging = (
  416. None
  417. if patch_merging != "Conv" and patch_merging != "Pool"
  418. else patch_merging
  419. )
  420. self.patch_embed = PatchEmbed(
  421. img_size=img_size,
  422. in_channels=in_channels,
  423. embed_dim=embed_dim[0],
  424. sub_num=sub_num,
  425. )
  426. num_patches = self.patch_embed.num_patches
  427. self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
  428. self.pos_embed = self.create_parameter(
  429. shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_
  430. )
  431. self.add_parameter("pos_embed", self.pos_embed)
  432. self.pos_drop = nn.Dropout(p=drop_rate)
  433. Block_unit = eval(block_unit)
  434. dpr = np.linspace(0, drop_path_rate, sum(depth))
  435. self.blocks1 = nn.LayerList(
  436. [
  437. Block_unit(
  438. dim=embed_dim[0],
  439. num_heads=num_heads[0],
  440. mixer=mixer[0 : depth[0]][i],
  441. HW=self.HW,
  442. local_mixer=local_mixer[0],
  443. mlp_ratio=mlp_ratio,
  444. qkv_bias=qkv_bias,
  445. qk_scale=qk_scale,
  446. drop=drop_rate,
  447. act_layer=eval(act),
  448. attn_drop=attn_drop_rate,
  449. drop_path=dpr[0 : depth[0]][i],
  450. norm_layer=norm_layer,
  451. epsilon=epsilon,
  452. prenorm=prenorm,
  453. )
  454. for i in range(depth[0])
  455. ]
  456. )
  457. if patch_merging is not None:
  458. self.sub_sample1 = SubSample(
  459. embed_dim[0],
  460. embed_dim[1],
  461. sub_norm=sub_norm,
  462. stride=[2, 1],
  463. types=patch_merging,
  464. )
  465. HW = [self.HW[0] // 2, self.HW[1]]
  466. else:
  467. HW = self.HW
  468. self.patch_merging = patch_merging
  469. self.blocks2 = nn.LayerList(
  470. [
  471. Block_unit(
  472. dim=embed_dim[1],
  473. num_heads=num_heads[1],
  474. mixer=mixer[depth[0] : depth[0] + depth[1]][i],
  475. HW=HW,
  476. local_mixer=local_mixer[1],
  477. mlp_ratio=mlp_ratio,
  478. qkv_bias=qkv_bias,
  479. qk_scale=qk_scale,
  480. drop=drop_rate,
  481. act_layer=eval(act),
  482. attn_drop=attn_drop_rate,
  483. drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
  484. norm_layer=norm_layer,
  485. epsilon=epsilon,
  486. prenorm=prenorm,
  487. )
  488. for i in range(depth[1])
  489. ]
  490. )
  491. if patch_merging is not None:
  492. self.sub_sample2 = SubSample(
  493. embed_dim[1],
  494. embed_dim[2],
  495. sub_norm=sub_norm,
  496. stride=[2, 1],
  497. types=patch_merging,
  498. )
  499. HW = [self.HW[0] // 4, self.HW[1]]
  500. else:
  501. HW = self.HW
  502. self.blocks3 = nn.LayerList(
  503. [
  504. Block_unit(
  505. dim=embed_dim[2],
  506. num_heads=num_heads[2],
  507. mixer=mixer[depth[0] + depth[1] :][i],
  508. HW=HW,
  509. local_mixer=local_mixer[2],
  510. mlp_ratio=mlp_ratio,
  511. qkv_bias=qkv_bias,
  512. qk_scale=qk_scale,
  513. drop=drop_rate,
  514. act_layer=eval(act),
  515. attn_drop=attn_drop_rate,
  516. drop_path=dpr[depth[0] + depth[1] :][i],
  517. norm_layer=norm_layer,
  518. epsilon=epsilon,
  519. prenorm=prenorm,
  520. )
  521. for i in range(depth[2])
  522. ]
  523. )
  524. self.last_stage = last_stage
  525. if last_stage:
  526. self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num])
  527. self.last_conv = nn.Conv2D(
  528. in_channels=embed_dim[2],
  529. out_channels=self.out_channels,
  530. kernel_size=1,
  531. stride=1,
  532. padding=0,
  533. bias_attr=False,
  534. )
  535. self.hardswish = nn.Hardswish()
  536. self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
  537. if not prenorm:
  538. self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
  539. self.use_lenhead = use_lenhead
  540. if use_lenhead:
  541. self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
  542. self.hardswish_len = nn.Hardswish()
  543. self.dropout_len = nn.Dropout(p=last_drop, mode="downscale_in_infer")
  544. trunc_normal_(self.pos_embed)
  545. self.apply(self._init_weights)
  546. def _init_weights(self, m):
  547. if isinstance(m, nn.Linear):
  548. trunc_normal_(m.weight)
  549. if isinstance(m, nn.Linear) and m.bias is not None:
  550. zeros_(m.bias)
  551. elif isinstance(m, nn.LayerNorm):
  552. zeros_(m.bias)
  553. ones_(m.weight)
  554. def forward_features(self, x):
  555. x = self.patch_embed(x)
  556. x = x + self.pos_embed
  557. x = self.pos_drop(x)
  558. for blk in self.blocks1:
  559. x = blk(x)
  560. if self.patch_merging is not None:
  561. x = self.sub_sample1(
  562. x.transpose([0, 2, 1]).reshape(
  563. [0, self.embed_dim[0], self.HW[0], self.HW[1]]
  564. )
  565. )
  566. for blk in self.blocks2:
  567. x = blk(x)
  568. if self.patch_merging is not None:
  569. x = self.sub_sample2(
  570. x.transpose([0, 2, 1]).reshape(
  571. [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
  572. )
  573. )
  574. for blk in self.blocks3:
  575. x = blk(x)
  576. if not self.prenorm:
  577. x = self.norm(x)
  578. return x
  579. def forward(self, x):
  580. x = self.forward_features(x)
  581. if self.use_lenhead:
  582. len_x = self.len_conv(x.mean(1))
  583. len_x = self.dropout_len(self.hardswish_len(len_x))
  584. if self.last_stage:
  585. if self.patch_merging is not None:
  586. h = self.HW[0] // 4
  587. else:
  588. h = self.HW[0]
  589. x = self.avg_pool(
  590. x.transpose([0, 2, 1]).reshape([0, self.embed_dim[2], h, self.HW[1]])
  591. )
  592. x = self.last_conv(x)
  593. x = self.hardswish(x)
  594. x = self.dropout(x)
  595. if self.use_lenhead:
  596. return x, len_x
  597. return x