rec_cppd_head.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. # copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. try:
  18. from collections import Callable
  19. except:
  20. from collections.abc import Callable
  21. import numpy as np
  22. import paddle
  23. from paddle import nn
  24. from paddle.nn import functional as F
  25. from ppocr.modeling.heads.rec_nrtr_head import Embeddings
  26. from ppocr.modeling.backbones.rec_svtrnet import (
  27. DropPath,
  28. Identity,
  29. trunc_normal_,
  30. zeros_,
  31. ones_,
  32. Mlp,
  33. )
  34. class Attention(nn.Layer):
  35. def __init__(
  36. self,
  37. dim,
  38. num_heads=8,
  39. qkv_bias=False,
  40. qk_scale=None,
  41. attn_drop=0.0,
  42. proj_drop=0.0,
  43. ):
  44. super().__init__()
  45. self.num_heads = num_heads
  46. head_dim = dim // num_heads
  47. self.scale = qk_scale or head_dim**-0.5
  48. self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
  49. self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
  50. self.attn_drop = nn.Dropout(attn_drop)
  51. self.proj = nn.Linear(dim, dim)
  52. self.proj_drop = nn.Dropout(proj_drop)
  53. def forward(self, q, kv):
  54. N, C = kv.shape[1:]
  55. QN = q.shape[1]
  56. q = (
  57. self.q(q)
  58. .reshape([-1, QN, self.num_heads, C // self.num_heads])
  59. .transpose([0, 2, 1, 3])
  60. )
  61. k, v = (
  62. self.kv(kv)
  63. .reshape([-1, N, 2, self.num_heads, C // self.num_heads])
  64. .transpose((2, 0, 3, 1, 4))
  65. )
  66. attn = q.matmul(k.transpose((0, 1, 3, 2))) * self.scale
  67. attn = F.softmax(attn, axis=-1)
  68. attn = self.attn_drop(attn)
  69. x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, QN, C))
  70. x = self.proj(x)
  71. x = self.proj_drop(x)
  72. return x
  73. class EdgeDecoderLayer(nn.Layer):
  74. def __init__(
  75. self,
  76. dim,
  77. num_heads,
  78. mlp_ratio=4.0,
  79. qkv_bias=False,
  80. qk_scale=None,
  81. drop=0.0,
  82. attn_drop=0.0,
  83. drop_path=[0.0, 0.0],
  84. act_layer=nn.GELU,
  85. norm_layer="nn.LayerNorm",
  86. epsilon=1e-6,
  87. ):
  88. super().__init__()
  89. self.head_dim = dim // num_heads
  90. self.scale = qk_scale or self.head_dim**-0.5
  91. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  92. self.drop_path1 = DropPath(drop_path[0]) if drop_path[0] > 0.0 else Identity()
  93. self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
  94. self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
  95. self.p = nn.Linear(dim, dim)
  96. self.cv = nn.Linear(dim, dim)
  97. self.pv = nn.Linear(dim, dim)
  98. self.dim = dim
  99. self.num_heads = num_heads
  100. self.p_proj = nn.Linear(dim, dim)
  101. mlp_hidden_dim = int(dim * mlp_ratio)
  102. self.mlp_ratio = mlp_ratio
  103. self.mlp = Mlp(
  104. in_features=dim,
  105. hidden_features=mlp_hidden_dim,
  106. act_layer=act_layer,
  107. drop=drop,
  108. )
  109. def forward(self, p, cv, pv):
  110. pN = p.shape[1]
  111. vN = cv.shape[1]
  112. p_shortcut = p
  113. p1 = (
  114. self.p(p)
  115. .reshape([-1, pN, self.num_heads, self.dim // self.num_heads])
  116. .transpose([0, 2, 1, 3])
  117. )
  118. cv1 = (
  119. self.cv(cv)
  120. .reshape([-1, vN, self.num_heads, self.dim // self.num_heads])
  121. .transpose([0, 2, 1, 3])
  122. )
  123. pv1 = (
  124. self.pv(pv)
  125. .reshape([-1, vN, self.num_heads, self.dim // self.num_heads])
  126. .transpose([0, 2, 1, 3])
  127. )
  128. edge = F.softmax(p1.matmul(pv1.transpose((0, 1, 3, 2))), -1) # B h N N
  129. p_c = (edge @ cv1).transpose((0, 2, 1, 3)).reshape((-1, pN, self.dim))
  130. x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
  131. x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
  132. return x
  133. class DecoderLayer(nn.Layer):
  134. def __init__(
  135. self,
  136. dim,
  137. num_heads,
  138. mlp_ratio=4.0,
  139. qkv_bias=False,
  140. qk_scale=None,
  141. drop=0.0,
  142. attn_drop=0.0,
  143. drop_path=0.0,
  144. act_layer=nn.GELU,
  145. norm_layer="nn.LayerNorm",
  146. epsilon=1e-6,
  147. ):
  148. super().__init__()
  149. if isinstance(norm_layer, str):
  150. self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
  151. self.normkv = eval(norm_layer)(dim, epsilon=epsilon)
  152. elif isinstance(norm_layer, Callable):
  153. self.norm1 = norm_layer(dim)
  154. self.normkv = norm_layer(dim)
  155. else:
  156. raise TypeError("The norm_layer must be str or paddle.nn.LayerNorm class")
  157. self.mixer = Attention(
  158. dim,
  159. num_heads=num_heads,
  160. qkv_bias=qkv_bias,
  161. qk_scale=qk_scale,
  162. attn_drop=attn_drop,
  163. proj_drop=drop,
  164. )
  165. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  166. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  167. if isinstance(norm_layer, str):
  168. self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
  169. elif isinstance(norm_layer, Callable):
  170. self.norm2 = norm_layer(dim)
  171. else:
  172. raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class")
  173. mlp_hidden_dim = int(dim * mlp_ratio)
  174. self.mlp_ratio = mlp_ratio
  175. self.mlp = Mlp(
  176. in_features=dim,
  177. hidden_features=mlp_hidden_dim,
  178. act_layer=act_layer,
  179. drop=drop,
  180. )
  181. def forward(self, q, kv):
  182. x1 = self.norm1(q + self.drop_path(self.mixer(q, kv)))
  183. x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
  184. return x
  185. class CPPDHead(nn.Layer):
  186. def __init__(
  187. self,
  188. in_channels,
  189. dim,
  190. out_channels,
  191. num_layer=2,
  192. drop_path_rate=0.1,
  193. max_len=25,
  194. vis_seq=50,
  195. ch=False,
  196. **kwargs,
  197. ):
  198. super(CPPDHead, self).__init__()
  199. self.out_channels = out_channels # none + 26 + 10
  200. self.dim = dim
  201. self.ch = ch
  202. self.max_len = max_len + 1 # max_len + eos
  203. self.char_node_embed = Embeddings(
  204. d_model=dim, vocab=self.out_channels, scale_embedding=True
  205. )
  206. self.pos_node_embed = Embeddings(
  207. d_model=dim, vocab=self.max_len, scale_embedding=True
  208. )
  209. dpr = np.linspace(0, drop_path_rate, num_layer + 1)
  210. self.char_node_decoder = nn.LayerList(
  211. [
  212. DecoderLayer(
  213. dim=dim,
  214. num_heads=dim // 32,
  215. mlp_ratio=4.0,
  216. qkv_bias=True,
  217. drop_path=dpr[i],
  218. )
  219. for i in range(num_layer)
  220. ]
  221. )
  222. self.pos_node_decoder = nn.LayerList(
  223. [
  224. DecoderLayer(
  225. dim=dim,
  226. num_heads=dim // 32,
  227. mlp_ratio=4.0,
  228. qkv_bias=True,
  229. drop_path=dpr[i],
  230. )
  231. for i in range(num_layer)
  232. ]
  233. )
  234. self.edge_decoder = EdgeDecoderLayer(
  235. dim=dim,
  236. num_heads=dim // 32,
  237. mlp_ratio=4.0,
  238. qkv_bias=True,
  239. drop_path=dpr[num_layer : num_layer + 1],
  240. )
  241. self.char_pos_embed = self.create_parameter(
  242. shape=[1, self.max_len, dim], default_initializer=zeros_
  243. )
  244. self.add_parameter("char_pos_embed", self.char_pos_embed)
  245. self.vis_pos_embed = self.create_parameter(
  246. shape=[1, vis_seq, dim], default_initializer=zeros_
  247. )
  248. self.add_parameter("vis_pos_embed", self.vis_pos_embed)
  249. self.char_node_fc1 = nn.Linear(dim, max_len)
  250. self.pos_node_fc1 = nn.Linear(dim, self.max_len)
  251. self.edge_fc = nn.Linear(dim, self.out_channels)
  252. trunc_normal_(self.char_pos_embed)
  253. trunc_normal_(self.vis_pos_embed)
  254. self.apply(self._init_weights)
  255. def _init_weights(self, m):
  256. if isinstance(m, nn.Linear):
  257. trunc_normal_(m.weight)
  258. if isinstance(m, nn.Linear) and m.bias is not None:
  259. zeros_(m.bias)
  260. elif isinstance(m, nn.LayerNorm):
  261. zeros_(m.bias)
  262. ones_(m.weight)
  263. def forward(self, x, targets=None, epoch=0):
  264. if self.training:
  265. return self.forward_train(x, targets, epoch)
  266. else:
  267. return self.forward_test(x)
  268. def forward_test(self, x):
  269. visual_feats = x + self.vis_pos_embed
  270. bs = visual_feats.shape[0]
  271. pos_node_embed = (
  272. self.pos_node_embed(paddle.arange(self.max_len)).unsqueeze(0)
  273. + self.char_pos_embed
  274. )
  275. pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])
  276. char_vis_node_query = visual_feats
  277. pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1)
  278. for char_decoder_layer, pos_decoder_layer in zip(
  279. self.char_node_decoder, self.pos_node_decoder
  280. ):
  281. char_vis_node_query = char_decoder_layer(
  282. char_vis_node_query, char_vis_node_query
  283. )
  284. pos_vis_node_query = pos_decoder_layer(
  285. pos_vis_node_query, pos_vis_node_query[:, self.max_len :, :]
  286. )
  287. pos_node_query = pos_vis_node_query[:, : self.max_len, :]
  288. char_vis_feats = char_vis_node_query
  289. pos_node_feats = self.edge_decoder(
  290. pos_node_query, char_vis_feats, char_vis_feats
  291. ) # B, 26, dim
  292. edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
  293. edge_logits = F.softmax(edge_feats, -1)
  294. return edge_logits
  295. def forward_train(self, x, targets=None, epoch=0):
  296. visual_feats = x + self.vis_pos_embed
  297. bs = visual_feats.shape[0]
  298. if self.ch:
  299. char_node_embed = self.char_node_embed(targets[-2])
  300. else:
  301. char_node_embed = self.char_node_embed(
  302. paddle.arange(self.out_channels)
  303. ).unsqueeze(0)
  304. char_node_embed = paddle.tile(char_node_embed, [bs, 1, 1])
  305. counting_char_num = char_node_embed.shape[1]
  306. pos_node_embed = (
  307. self.pos_node_embed(paddle.arange(self.max_len)).unsqueeze(0)
  308. + self.char_pos_embed
  309. )
  310. pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1])
  311. node_feats = []
  312. char_vis_node_query = paddle.concat([char_node_embed, visual_feats], 1)
  313. pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1)
  314. for char_decoder_layer, pos_decoder_layer in zip(
  315. self.char_node_decoder, self.pos_node_decoder
  316. ):
  317. char_vis_node_query = char_decoder_layer(
  318. char_vis_node_query, char_vis_node_query[:, counting_char_num:, :]
  319. )
  320. pos_vis_node_query = pos_decoder_layer(
  321. pos_vis_node_query, pos_vis_node_query[:, self.max_len :, :]
  322. )
  323. char_node_query = char_vis_node_query[:, :counting_char_num, :]
  324. pos_node_query = pos_vis_node_query[:, : self.max_len, :]
  325. char_vis_feats = char_vis_node_query[:, counting_char_num:, :]
  326. char_node_feats1 = self.char_node_fc1(char_node_query)
  327. pos_node_feats1 = self.pos_node_fc1(pos_node_query)
  328. diag_mask = (
  329. paddle.eye(pos_node_feats1.shape[1])
  330. .unsqueeze(0)
  331. .tile([pos_node_feats1.shape[0], 1, 1])
  332. )
  333. pos_node_feats1 = (pos_node_feats1 * diag_mask).sum(-1)
  334. node_feats.append(char_node_feats1)
  335. node_feats.append(pos_node_feats1)
  336. pos_node_feats = self.edge_decoder(
  337. pos_node_query, char_vis_feats, char_vis_feats
  338. ) # B, 26, dim
  339. edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37
  340. return node_feats, edge_feats