rec_hybridvit.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. from itertools import repeat
  22. import collections
  23. import math
  24. from functools import partial
  25. import paddle
  26. import paddle.nn as nn
  27. import paddle.nn.functional as F
  28. from ppocr.modeling.backbones.rec_resnetv2 import (
  29. ResNetV2,
  30. StdConv2dSame,
  31. DropPath,
  32. get_padding,
  33. )
  34. from paddle.nn.initializer import (
  35. TruncatedNormal,
  36. Constant,
  37. Normal,
  38. KaimingUniform,
  39. XavierUniform,
  40. )
  41. normal_ = Normal(mean=0.0, std=1e-6)
  42. zeros_ = Constant(value=0.0)
  43. ones_ = Constant(value=1.0)
  44. kaiming_normal_ = KaimingUniform(nonlinearity="relu")
  45. trunc_normal_ = TruncatedNormal(std=0.02)
  46. xavier_uniform_ = XavierUniform()
  47. def _ntuple(n):
  48. def parse(x):
  49. if isinstance(x, collections.abc.Iterable):
  50. return x
  51. return tuple(repeat(x, n))
  52. return parse
  53. to_1tuple = _ntuple(1)
  54. to_2tuple = _ntuple(2)
  55. to_3tuple = _ntuple(3)
  56. to_4tuple = _ntuple(4)
  57. to_ntuple = _ntuple
  58. class Conv2dAlign(nn.Conv2D):
  59. """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
  60. Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
  61. https://arxiv.org/abs/1903.10520v2
  62. """
  63. def __init__(
  64. self,
  65. in_channel,
  66. out_channels,
  67. kernel_size,
  68. stride=1,
  69. padding=0,
  70. dilation=1,
  71. groups=1,
  72. bias=True,
  73. eps=1e-6,
  74. ):
  75. super().__init__(
  76. in_channel,
  77. out_channels,
  78. kernel_size,
  79. stride=stride,
  80. padding=padding,
  81. dilation=dilation,
  82. groups=groups,
  83. bias_attr=bias,
  84. weight_attr=True,
  85. )
  86. self.eps = eps
  87. def forward(self, x):
  88. x = F.conv2d(
  89. x,
  90. self.weight,
  91. self.bias,
  92. self._stride,
  93. self._padding,
  94. self._dilation,
  95. self._groups,
  96. )
  97. return x
  98. class HybridEmbed(nn.Layer):
  99. """CNN Feature Map Embedding
  100. Extract feature map from CNN, flatten, project to embedding dim.
  101. """
  102. def __init__(
  103. self,
  104. backbone,
  105. img_size=224,
  106. patch_size=1,
  107. feature_size=None,
  108. in_chans=3,
  109. embed_dim=768,
  110. ):
  111. super().__init__()
  112. assert isinstance(backbone, nn.Layer)
  113. img_size = to_2tuple(img_size)
  114. patch_size = to_2tuple(patch_size)
  115. self.img_size = img_size
  116. self.patch_size = patch_size
  117. self.backbone = backbone
  118. feature_dim = 1024
  119. feature_size = (42, 12)
  120. patch_size = (1, 1)
  121. assert (
  122. feature_size[0] % patch_size[0] == 0
  123. and feature_size[1] % patch_size[1] == 0
  124. )
  125. self.grid_size = (
  126. feature_size[0] // patch_size[0],
  127. feature_size[1] // patch_size[1],
  128. )
  129. self.num_patches = self.grid_size[0] * self.grid_size[1]
  130. self.proj = nn.Conv2D(
  131. feature_dim,
  132. embed_dim,
  133. kernel_size=patch_size,
  134. stride=patch_size,
  135. weight_attr=True,
  136. bias_attr=True,
  137. )
  138. def forward(self, x):
  139. x = self.backbone(x)
  140. if isinstance(x, (list, tuple)):
  141. x = x[-1] # last feature if backbone outputs list/tuple of features
  142. x = self.proj(x).flatten(2).transpose([0, 2, 1])
  143. return x
  144. class myLinear(nn.Linear):
  145. def __init__(self, in_channel, out_channels, weight_attr=True, bias_attr=True):
  146. super().__init__(
  147. in_channel, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
  148. )
  149. def forward(self, x):
  150. return paddle.matmul(x, self.weight, transpose_y=True) + self.bias
  151. class Attention(nn.Layer):
  152. def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
  153. super().__init__()
  154. self.num_heads = num_heads
  155. head_dim = dim // num_heads
  156. self.scale = head_dim**-0.5
  157. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  158. self.attn_drop = nn.Dropout(attn_drop)
  159. self.proj = myLinear(dim, dim, weight_attr=True, bias_attr=True)
  160. self.proj_drop = nn.Dropout(proj_drop)
  161. def forward(self, x):
  162. B, N, C = x.shape
  163. qkv = (
  164. self.qkv(x)
  165. .reshape([B, N, 3, self.num_heads, C // self.num_heads])
  166. .transpose([2, 0, 3, 1, 4])
  167. )
  168. q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  169. attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
  170. attn = F.softmax(attn, axis=-1)
  171. attn = self.attn_drop(attn)
  172. x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
  173. x = self.proj(x)
  174. x = self.proj_drop(x)
  175. return x
  176. class Mlp(nn.Layer):
  177. """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
  178. def __init__(
  179. self,
  180. in_features,
  181. hidden_features=None,
  182. out_features=None,
  183. act_layer=nn.GELU,
  184. drop=0.0,
  185. ):
  186. super().__init__()
  187. out_features = out_features or in_features
  188. hidden_features = hidden_features or in_features
  189. drop_probs = to_2tuple(drop)
  190. self.fc1 = nn.Linear(in_features, hidden_features)
  191. self.act = act_layer()
  192. self.drop1 = nn.Dropout(drop_probs[0])
  193. self.fc2 = nn.Linear(hidden_features, out_features)
  194. self.drop2 = nn.Dropout(drop_probs[1])
  195. def forward(self, x):
  196. x = self.fc1(x)
  197. x = self.act(x)
  198. x = self.drop1(x)
  199. x = self.fc2(x)
  200. x = self.drop2(x)
  201. return x
  202. class Block(nn.Layer):
  203. def __init__(
  204. self,
  205. dim,
  206. num_heads,
  207. mlp_ratio=4.0,
  208. qkv_bias=False,
  209. drop=0.0,
  210. attn_drop=0.0,
  211. drop_path=0.0,
  212. act_layer=nn.GELU,
  213. norm_layer=nn.LayerNorm,
  214. ):
  215. super().__init__()
  216. self.norm1 = norm_layer(dim)
  217. self.attn = Attention(
  218. dim,
  219. num_heads=num_heads,
  220. qkv_bias=qkv_bias,
  221. attn_drop=attn_drop,
  222. proj_drop=drop,
  223. )
  224. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  225. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  226. self.norm2 = norm_layer(dim)
  227. mlp_hidden_dim = int(dim * mlp_ratio)
  228. self.mlp = Mlp(
  229. in_features=dim,
  230. hidden_features=mlp_hidden_dim,
  231. act_layer=act_layer,
  232. drop=drop,
  233. )
  234. def forward(self, x):
  235. x = x + self.drop_path(self.attn(self.norm1(x)))
  236. x = x + self.drop_path(self.mlp(self.norm2(x)))
  237. return x
  238. class HybridTransformer(nn.Layer):
  239. """Implementation of HybridTransformer.
  240. Args:
  241. x: input images with shape [N, 1, H, W]
  242. label: LaTeX-OCR labels with shape [N, L] , L is the max sequence length
  243. attention_mask: LaTeX-OCR attention mask with shape [N, L] , L is the max sequence length
  244. Returns:
  245. The encoded features with shape [N, 1, H//16, W//16]
  246. """
  247. def __init__(
  248. self,
  249. backbone_layers=[2, 3, 7],
  250. input_channel=1,
  251. is_predict=False,
  252. is_export=False,
  253. img_size=(224, 224),
  254. patch_size=16,
  255. num_classes=1000,
  256. embed_dim=768,
  257. depth=12,
  258. num_heads=12,
  259. mlp_ratio=4.0,
  260. qkv_bias=True,
  261. representation_size=None,
  262. distilled=False,
  263. drop_rate=0.0,
  264. attn_drop_rate=0.0,
  265. drop_path_rate=0.0,
  266. embed_layer=None,
  267. norm_layer=None,
  268. act_layer=None,
  269. weight_init="",
  270. **kwargs,
  271. ):
  272. super(HybridTransformer, self).__init__()
  273. self.num_classes = num_classes
  274. self.num_features = self.embed_dim = (
  275. embed_dim # num_features for consistency with other models
  276. )
  277. self.num_tokens = 2 if distilled else 1
  278. norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
  279. act_layer = act_layer or nn.GELU
  280. self.height, self.width = img_size
  281. self.patch_size = patch_size
  282. backbone = ResNetV2(
  283. layers=backbone_layers,
  284. num_classes=0,
  285. global_pool="",
  286. in_chans=input_channel,
  287. preact=False,
  288. stem_type="same",
  289. conv_layer=StdConv2dSame,
  290. is_export=is_export,
  291. )
  292. min_patch_size = 2 ** (len(backbone_layers) + 1)
  293. self.patch_embed = HybridEmbed(
  294. img_size=img_size,
  295. patch_size=patch_size // min_patch_size,
  296. in_chans=input_channel,
  297. embed_dim=embed_dim,
  298. backbone=backbone,
  299. )
  300. num_patches = self.patch_embed.num_patches
  301. self.cls_token = paddle.create_parameter([1, 1, embed_dim], dtype="float32")
  302. self.dist_token = (
  303. paddle.create_parameter(
  304. [1, 1, embed_dim],
  305. dtype="float32",
  306. )
  307. if distilled
  308. else None
  309. )
  310. self.pos_embed = paddle.create_parameter(
  311. [1, num_patches + self.num_tokens, embed_dim], dtype="float32"
  312. )
  313. self.pos_drop = nn.Dropout(p=drop_rate)
  314. zeros_(self.cls_token)
  315. if self.dist_token is not None:
  316. zeros_(self.dist_token)
  317. zeros_(self.pos_embed)
  318. dpr = [
  319. x.item() for x in paddle.linspace(0, drop_path_rate, depth)
  320. ] # stochastic depth decay rule
  321. self.blocks = nn.Sequential(
  322. *[
  323. Block(
  324. dim=embed_dim,
  325. num_heads=num_heads,
  326. mlp_ratio=mlp_ratio,
  327. qkv_bias=qkv_bias,
  328. drop=drop_rate,
  329. attn_drop=attn_drop_rate,
  330. drop_path=dpr[i],
  331. norm_layer=norm_layer,
  332. act_layer=act_layer,
  333. )
  334. for i in range(depth)
  335. ]
  336. )
  337. self.norm = norm_layer(embed_dim)
  338. # Representation layer
  339. if representation_size and not distilled:
  340. self.num_features = representation_size
  341. self.pre_logits = nn.Sequential(
  342. ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh())
  343. )
  344. else:
  345. self.pre_logits = nn.Identity()
  346. # Classifier head(s)
  347. self.head = (
  348. nn.Linear(self.num_features, num_classes)
  349. if num_classes > 0
  350. else nn.Identity()
  351. )
  352. self.head_dist = None
  353. if distilled:
  354. self.head_dist = (
  355. nn.Linear(self.embed_dim, self.num_classes)
  356. if num_classes > 0
  357. else nn.Identity()
  358. )
  359. self.init_weights(weight_init)
  360. self.out_channels = embed_dim
  361. self.is_predict = is_predict
  362. self.is_export = is_export
  363. def init_weights(self, mode=""):
  364. assert mode in ("jax", "jax_nlhb", "nlhb", "")
  365. head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
  366. trunc_normal_(self.pos_embed)
  367. trunc_normal_(self.cls_token)
  368. self.apply(_init_vit_weights)
  369. def _init_weights(self, m):
  370. # this fn left here for compat with downstream users
  371. _init_vit_weights(m)
  372. def load_pretrained(self, checkpoint_path, prefix=""):
  373. raise NotImplementedError
  374. def no_weight_decay(self):
  375. return {"pos_embed", "cls_token", "dist_token"}
  376. def get_classifier(self):
  377. if self.dist_token is None:
  378. return self.head
  379. else:
  380. return self.head, self.head_dist
  381. def reset_classifier(self, num_classes, global_pool=""):
  382. self.num_classes = num_classes
  383. self.head = (
  384. nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  385. )
  386. if self.num_tokens == 2:
  387. self.head_dist = (
  388. nn.Linear(self.embed_dim, self.num_classes)
  389. if num_classes > 0
  390. else nn.Identity()
  391. )
  392. def forward_features(self, x):
  393. B, c, h, w = x.shape
  394. x = self.patch_embed(x)
  395. cls_tokens = self.cls_token.expand(
  396. [B, -1, -1]
  397. ) # stole cls_tokens impl from Phil Wang, thanks
  398. x = paddle.concat((cls_tokens, x), axis=1)
  399. h, w = h // self.patch_size, w // self.patch_size
  400. repeat_tensor = (
  401. paddle.arange(h) * (self.width // self.patch_size - w)
  402. ).reshape([-1, 1])
  403. repeat_tensor = paddle.repeat_interleave(
  404. repeat_tensor, paddle.to_tensor(w), axis=1
  405. ).reshape([-1])
  406. pos_emb_ind = repeat_tensor + paddle.arange(h * w)
  407. pos_emb_ind = paddle.concat(
  408. (paddle.zeros([1], dtype="int64"), pos_emb_ind + 1), axis=0
  409. ).cast(paddle.int64)
  410. x += self.pos_embed[:, pos_emb_ind]
  411. x = self.pos_drop(x)
  412. for blk in self.blocks:
  413. x = blk(x)
  414. x = self.norm(x)
  415. return x
  416. def forward(self, input_data):
  417. if self.training:
  418. x, label, attention_mask = input_data
  419. else:
  420. if isinstance(input_data, list):
  421. x = input_data[0]
  422. else:
  423. x = input_data
  424. x = self.forward_features(x)
  425. x = self.head(x)
  426. if self.training:
  427. return x, label, attention_mask
  428. else:
  429. return x
  430. def _init_vit_weights(
  431. module: nn.Layer, name: str = "", head_bias: float = 0.0, jax_impl: bool = False
  432. ):
  433. """ViT weight initialization
  434. * When called without n, head_bias, jax_impl args it will behave exactly the same
  435. as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
  436. * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
  437. """
  438. if isinstance(module, nn.Linear):
  439. if name.startswith("head"):
  440. zeros_(module.weight)
  441. constant_ = Constant(value=head_bias)
  442. constant_(module.bias, head_bias)
  443. elif name.startswith("pre_logits"):
  444. zeros_(module.bias)
  445. else:
  446. if jax_impl:
  447. xavier_uniform_(module.weight)
  448. if module.bias is not None:
  449. if "mlp" in name:
  450. normal_(module.bias)
  451. else:
  452. zeros_(module.bias)
  453. else:
  454. trunc_normal_(module.weight)
  455. if module.bias is not None:
  456. zeros_(module.bias)
  457. elif jax_impl and isinstance(module, nn.Conv2D):
  458. # NOTE conv was left to pytorch default in my original init
  459. if module.bias is not None:
  460. zeros_(module.bias)
  461. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2D)):
  462. zeros_(module.bias)
  463. ones_(module.weight)