rec_vitstr.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is adapted from:
  16. https://github.com/roatienza/deep-text-recognition-benchmark/blob/master/modules/vitstr.py
  17. """
  18. import numpy as np
  19. import paddle
  20. import paddle.nn as nn
  21. from ppocr.modeling.backbones.rec_svtrnet import (
  22. Block,
  23. PatchEmbed,
  24. zeros_,
  25. trunc_normal_,
  26. ones_,
  27. )
  28. scale_dim_heads = {"tiny": [192, 3], "small": [384, 6], "base": [768, 12]}
  29. class ViTSTR(nn.Layer):
  30. def __init__(
  31. self,
  32. img_size=[224, 224],
  33. in_channels=1,
  34. scale="tiny",
  35. seqlen=27,
  36. patch_size=[16, 16],
  37. embed_dim=None,
  38. depth=12,
  39. num_heads=None,
  40. mlp_ratio=4,
  41. qkv_bias=True,
  42. qk_scale=None,
  43. drop_path_rate=0.0,
  44. drop_rate=0.0,
  45. attn_drop_rate=0.0,
  46. norm_layer="nn.LayerNorm",
  47. act_layer="nn.GELU",
  48. epsilon=1e-6,
  49. out_channels=None,
  50. **kwargs,
  51. ):
  52. super().__init__()
  53. self.seqlen = seqlen
  54. embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[scale][0]
  55. num_heads = num_heads if num_heads is not None else scale_dim_heads[scale][1]
  56. out_channels = out_channels if out_channels is not None else embed_dim
  57. self.patch_embed = PatchEmbed(
  58. img_size=img_size,
  59. in_channels=in_channels,
  60. embed_dim=embed_dim,
  61. patch_size=patch_size,
  62. mode="linear",
  63. )
  64. num_patches = self.patch_embed.num_patches
  65. self.pos_embed = self.create_parameter(
  66. shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_
  67. )
  68. self.add_parameter("pos_embed", self.pos_embed)
  69. self.cls_token = self.create_parameter(
  70. shape=[1, 1, embed_dim], default_initializer=zeros_
  71. )
  72. self.add_parameter("cls_token", self.cls_token)
  73. self.pos_drop = nn.Dropout(p=drop_rate)
  74. dpr = np.linspace(0, drop_path_rate, depth)
  75. self.blocks = nn.LayerList(
  76. [
  77. Block(
  78. dim=embed_dim,
  79. num_heads=num_heads,
  80. mlp_ratio=mlp_ratio,
  81. qkv_bias=qkv_bias,
  82. qk_scale=qk_scale,
  83. drop=drop_rate,
  84. attn_drop=attn_drop_rate,
  85. drop_path=dpr[i],
  86. norm_layer=norm_layer,
  87. act_layer=eval(act_layer),
  88. epsilon=epsilon,
  89. prenorm=False,
  90. )
  91. for i in range(depth)
  92. ]
  93. )
  94. self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
  95. self.out_channels = out_channels
  96. trunc_normal_(self.pos_embed)
  97. trunc_normal_(self.cls_token)
  98. self.apply(self._init_weights)
  99. def _init_weights(self, m):
  100. if isinstance(m, nn.Linear):
  101. trunc_normal_(m.weight)
  102. if isinstance(m, nn.Linear) and m.bias is not None:
  103. zeros_(m.bias)
  104. elif isinstance(m, nn.LayerNorm):
  105. zeros_(m.bias)
  106. ones_(m.weight)
  107. def forward_features(self, x):
  108. B = x.shape[0]
  109. x = self.patch_embed(x)
  110. cls_tokens = paddle.tile(self.cls_token, repeat_times=[B, 1, 1])
  111. x = paddle.concat((cls_tokens, x), axis=1)
  112. x = x + self.pos_embed
  113. x = self.pos_drop(x)
  114. for blk in self.blocks:
  115. x = blk(x)
  116. x = self.norm(x)
  117. return x
  118. def forward(self, x):
  119. x = self.forward_features(x)
  120. x = x[:, : self.seqlen]
  121. return x.transpose([0, 2, 1]).unsqueeze(2)