rec_vit.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle import ParamAttr
  15. from paddle.nn.initializer import KaimingNormal
  16. import numpy as np
  17. import paddle
  18. import paddle.nn as nn
  19. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  20. trunc_normal_ = TruncatedNormal(std=0.02)
  21. normal_ = Normal
  22. zeros_ = Constant(value=0.0)
  23. ones_ = Constant(value=1.0)
  24. def drop_path(x, drop_prob=0.0, training=False):
  25. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  26. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  27. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  28. """
  29. if drop_prob == 0.0 or not training:
  30. return x
  31. keep_prob = paddle.to_tensor(1 - drop_prob)
  32. shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  33. random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
  34. random_tensor = paddle.floor(random_tensor) # binarize
  35. output = x.divide(keep_prob) * random_tensor
  36. return output
  37. class DropPath(nn.Layer):
  38. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  39. def __init__(self, drop_prob=None):
  40. super(DropPath, self).__init__()
  41. self.drop_prob = drop_prob
  42. def forward(self, x):
  43. return drop_path(x, self.drop_prob, self.training)
  44. class Identity(nn.Layer):
  45. def __init__(self):
  46. super(Identity, self).__init__()
  47. def forward(self, input):
  48. return input
  49. class Mlp(nn.Layer):
  50. def __init__(
  51. self,
  52. in_features,
  53. hidden_features=None,
  54. out_features=None,
  55. act_layer=nn.GELU,
  56. drop=0.0,
  57. ):
  58. super().__init__()
  59. out_features = out_features or in_features
  60. hidden_features = hidden_features or in_features
  61. self.fc1 = nn.Linear(in_features, hidden_features)
  62. self.act = act_layer()
  63. self.fc2 = nn.Linear(hidden_features, out_features)
  64. self.drop = nn.Dropout(drop)
  65. def forward(self, x):
  66. x = self.fc1(x)
  67. x = self.act(x)
  68. x = self.drop(x)
  69. x = self.fc2(x)
  70. x = self.drop(x)
  71. return x
  72. class Attention(nn.Layer):
  73. def __init__(
  74. self,
  75. dim,
  76. num_heads=8,
  77. qkv_bias=False,
  78. qk_scale=None,
  79. attn_drop=0.0,
  80. proj_drop=0.0,
  81. ):
  82. super().__init__()
  83. self.num_heads = num_heads
  84. self.dim = dim
  85. head_dim = dim // num_heads
  86. self.scale = qk_scale or head_dim**-0.5
  87. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  88. self.attn_drop = nn.Dropout(attn_drop)
  89. self.proj = nn.Linear(dim, dim)
  90. self.proj_drop = nn.Dropout(proj_drop)
  91. def forward(self, x):
  92. qkv = paddle.reshape(
  93. self.qkv(x), (0, -1, 3, self.num_heads, self.dim // self.num_heads)
  94. ).transpose((2, 0, 3, 1, 4))
  95. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  96. attn = q.matmul(k.transpose((0, 1, 3, 2)))
  97. attn = nn.functional.softmax(attn, axis=-1)
  98. attn = self.attn_drop(attn)
  99. x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
  100. x = self.proj(x)
  101. x = self.proj_drop(x)
  102. return x
  103. class Block(nn.Layer):
  104. def __init__(
  105. self,
  106. dim,
  107. num_heads,
  108. mlp_ratio=4.0,
  109. qkv_bias=False,
  110. qk_scale=None,
  111. drop=0.0,
  112. attn_drop=0.0,
  113. drop_path=0.0,
  114. act_layer=nn.GELU,
  115. norm_layer="nn.LayerNorm",
  116. epsilon=1e-6,
  117. prenorm=True,
  118. ):
  119. super().__init__()
  120. if isinstance(norm_layer, str):
  121. self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
  122. else:
  123. self.norm1 = norm_layer(dim)
  124. self.mixer = Attention(
  125. dim,
  126. num_heads=num_heads,
  127. qkv_bias=qkv_bias,
  128. qk_scale=qk_scale,
  129. attn_drop=attn_drop,
  130. proj_drop=drop,
  131. )
  132. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  133. if isinstance(norm_layer, str):
  134. self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
  135. else:
  136. self.norm2 = norm_layer(dim)
  137. mlp_hidden_dim = int(dim * mlp_ratio)
  138. self.mlp_ratio = mlp_ratio
  139. self.mlp = Mlp(
  140. in_features=dim,
  141. hidden_features=mlp_hidden_dim,
  142. act_layer=act_layer,
  143. drop=drop,
  144. )
  145. self.prenorm = prenorm
  146. def forward(self, x):
  147. if self.prenorm:
  148. x = self.norm1(x + self.drop_path(self.mixer(x)))
  149. x = self.norm2(x + self.drop_path(self.mlp(x)))
  150. else:
  151. x = x + self.drop_path(self.mixer(self.norm1(x)))
  152. x = x + self.drop_path(self.mlp(self.norm2(x)))
  153. return x
  154. class ViT(nn.Layer):
  155. def __init__(
  156. self,
  157. img_size=[32, 128],
  158. patch_size=[4, 4],
  159. in_channels=3,
  160. embed_dim=384,
  161. depth=12,
  162. num_heads=6,
  163. mlp_ratio=4,
  164. qkv_bias=False,
  165. qk_scale=None,
  166. drop_rate=0.0,
  167. attn_drop_rate=0.0,
  168. drop_path_rate=0.1,
  169. norm_layer="nn.LayerNorm",
  170. epsilon=1e-6,
  171. act="nn.GELU",
  172. prenorm=False,
  173. **kwargs,
  174. ):
  175. super().__init__()
  176. self.embed_dim = embed_dim
  177. self.out_channels = embed_dim
  178. self.prenorm = prenorm
  179. self.patch_embed = nn.Conv2D(
  180. in_channels, embed_dim, patch_size, patch_size, padding=(0, 0)
  181. )
  182. self.pos_embed = self.create_parameter(
  183. shape=[1, 257, embed_dim], default_initializer=zeros_
  184. )
  185. self.add_parameter("pos_embed", self.pos_embed)
  186. self.pos_drop = nn.Dropout(p=drop_rate)
  187. dpr = np.linspace(0, drop_path_rate, depth)
  188. self.blocks1 = nn.LayerList(
  189. [
  190. Block(
  191. dim=embed_dim,
  192. num_heads=num_heads,
  193. mlp_ratio=mlp_ratio,
  194. qkv_bias=qkv_bias,
  195. qk_scale=qk_scale,
  196. drop=drop_rate,
  197. act_layer=eval(act),
  198. attn_drop=attn_drop_rate,
  199. drop_path=dpr[i],
  200. norm_layer=norm_layer,
  201. epsilon=epsilon,
  202. prenorm=prenorm,
  203. )
  204. for i in range(depth)
  205. ]
  206. )
  207. if not prenorm:
  208. self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
  209. self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
  210. self.last_conv = nn.Conv2D(
  211. in_channels=embed_dim,
  212. out_channels=self.out_channels,
  213. kernel_size=1,
  214. stride=1,
  215. padding=0,
  216. bias_attr=False,
  217. )
  218. self.hardswish = nn.Hardswish()
  219. self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")
  220. trunc_normal_(self.pos_embed)
  221. self.apply(self._init_weights)
  222. def _init_weights(self, m):
  223. if isinstance(m, nn.Linear):
  224. trunc_normal_(m.weight)
  225. if isinstance(m, nn.Linear) and m.bias is not None:
  226. zeros_(m.bias)
  227. elif isinstance(m, nn.LayerNorm):
  228. zeros_(m.bias)
  229. ones_(m.weight)
  230. def forward(self, x):
  231. x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
  232. x = x + self.pos_embed[:, 1:, :] # [:, :x.shape[1], :]
  233. x = self.pos_drop(x)
  234. for blk in self.blocks1:
  235. x = blk(x)
  236. if not self.prenorm:
  237. x = self.norm(x)
  238. x = self.avg_pool(x.transpose([0, 2, 1]).reshape([0, self.embed_dim, -1, 25]))
  239. x = self.last_conv(x)
  240. x = self.hardswish(x)
  241. x = self.dropout(x)
  242. return x