table_att_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import math
  18. import paddle
  19. import paddle.nn as nn
  20. from paddle import ParamAttr
  21. import paddle.nn.functional as F
  22. import numpy as np
  23. from .rec_att_head import AttentionGRUCell
  24. from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, Mlp
  25. def get_para_bias_attr(l2_decay, k):
  26. if l2_decay > 0:
  27. regularizer = paddle.regularizer.L2Decay(l2_decay)
  28. stdv = 1.0 / math.sqrt(k * 1.0)
  29. initializer = nn.initializer.Uniform(-stdv, stdv)
  30. else:
  31. regularizer = None
  32. initializer = None
  33. weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
  34. bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
  35. return [weight_attr, bias_attr]
  36. class TableAttentionHead(nn.Layer):
  37. def __init__(
  38. self,
  39. in_channels,
  40. hidden_size,
  41. in_max_len=488,
  42. max_text_length=800,
  43. out_channels=30,
  44. loc_reg_num=4,
  45. **kwargs,
  46. ):
  47. super(TableAttentionHead, self).__init__()
  48. self.input_size = in_channels[-1]
  49. self.hidden_size = hidden_size
  50. self.out_channels = out_channels
  51. self.max_text_length = max_text_length
  52. self.structure_attention_cell = AttentionGRUCell(
  53. self.input_size, hidden_size, self.out_channels, use_gru=False
  54. )
  55. self.structure_generator = nn.Linear(hidden_size, self.out_channels)
  56. self.in_max_len = in_max_len
  57. if self.in_max_len == 640:
  58. self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
  59. elif self.in_max_len == 800:
  60. self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
  61. else:
  62. self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
  63. self.loc_generator = nn.Linear(self.input_size + hidden_size, loc_reg_num)
  64. def _char_to_onehot(self, input_char, onehot_dim):
  65. input_ont_hot = F.one_hot(input_char, onehot_dim)
  66. return input_ont_hot
  67. def forward(self, inputs, targets=None):
  68. # if and else branch are both needed when you want to assign a variable
  69. # if you modify the var in just one branch, then the modification will not work.
  70. fea = inputs[-1]
  71. last_shape = int(np.prod(fea.shape[2:])) # gry added
  72. fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
  73. fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
  74. batch_size = fea.shape[0]
  75. hidden = paddle.zeros((batch_size, self.hidden_size))
  76. output_hiddens = paddle.zeros(
  77. (batch_size, self.max_text_length + 1, self.hidden_size)
  78. )
  79. if self.training and targets is not None:
  80. structure = targets[0]
  81. for i in range(self.max_text_length + 1):
  82. elem_onehots = self._char_to_onehot(
  83. structure[:, i], onehot_dim=self.out_channels
  84. )
  85. (outputs, hidden), alpha = self.structure_attention_cell(
  86. hidden, fea, elem_onehots
  87. )
  88. output_hiddens[:, i, :] = outputs
  89. structure_probs = self.structure_generator(output_hiddens)
  90. loc_fea = fea.transpose([0, 2, 1])
  91. loc_fea = self.loc_fea_trans(loc_fea)
  92. loc_fea = loc_fea.transpose([0, 2, 1])
  93. loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
  94. loc_preds = self.loc_generator(loc_concat)
  95. loc_preds = F.sigmoid(loc_preds)
  96. else:
  97. temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
  98. structure_probs = None
  99. loc_preds = None
  100. elem_onehots = None
  101. outputs = None
  102. alpha = None
  103. max_text_length = paddle.to_tensor(self.max_text_length)
  104. for i in range(max_text_length + 1):
  105. elem_onehots = self._char_to_onehot(
  106. temp_elem, onehot_dim=self.out_channels
  107. )
  108. (outputs, hidden), alpha = self.structure_attention_cell(
  109. hidden, fea, elem_onehots
  110. )
  111. output_hiddens[:, i, :] = outputs
  112. structure_probs_step = self.structure_generator(outputs)
  113. temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
  114. structure_probs = self.structure_generator(output_hiddens)
  115. structure_probs = F.softmax(structure_probs)
  116. loc_fea = fea.transpose([0, 2, 1])
  117. loc_fea = self.loc_fea_trans(loc_fea)
  118. loc_fea = loc_fea.transpose([0, 2, 1])
  119. loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
  120. loc_preds = self.loc_generator(loc_concat)
  121. loc_preds = F.sigmoid(loc_preds)
  122. return {"structure_probs": structure_probs, "loc_preds": loc_preds}
  123. class HWAttention(nn.Layer):
  124. def __init__(
  125. self,
  126. head_dim=32,
  127. qk_scale=None,
  128. attn_drop=0.0,
  129. ):
  130. super().__init__()
  131. self.head_dim = head_dim
  132. self.scale = qk_scale or self.head_dim**-0.5
  133. self.attn_drop = nn.Dropout(attn_drop)
  134. def forward(self, x):
  135. B, N, C = x.shape
  136. C = C // 3
  137. qkv = x.reshape([B, N, 3, C // self.head_dim, self.head_dim]).transpose(
  138. [2, 0, 3, 1, 4]
  139. )
  140. q, k, v = qkv.unbind(0)
  141. attn = q @ k.transpose([0, 1, 3, 2]) * self.scale
  142. attn = F.softmax(attn, -1)
  143. attn = self.attn_drop(attn)
  144. x = attn @ v
  145. x = x.transpose([0, 2, 1]).reshape([B, N, C])
  146. return x
  147. def img2windows(img, H_sp, W_sp):
  148. """
  149. img: B C H W
  150. """
  151. B, H, W, C = img.shape
  152. img_reshape = img.reshape([B, H // H_sp, H_sp, W // W_sp, W_sp, C])
  153. img_perm = img_reshape.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H_sp * W_sp, C])
  154. return img_perm
  155. def windows2img(img_splits_hw, H_sp, W_sp, H, W):
  156. """
  157. img_splits_hw: B' H W C
  158. """
  159. B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
  160. img = img_splits_hw.reshape([B, H // H_sp, W // W_sp, H_sp, W_sp, -1])
  161. img = img.transpose([0, 1, 3, 2, 4, 5]).flatten(1, 4)
  162. return img
  163. class Block(nn.Layer):
  164. def __init__(
  165. self,
  166. dim,
  167. num_heads,
  168. split_h=4,
  169. split_w=4,
  170. h_num_heads=None,
  171. w_num_heads=None,
  172. mlp_ratio=4.0,
  173. qkv_bias=False,
  174. qk_scale=None,
  175. drop=0.0,
  176. attn_drop=0.0,
  177. drop_path=0.0,
  178. act_layer=nn.GELU,
  179. norm_layer=nn.LayerNorm,
  180. eps=1e-6,
  181. ):
  182. super().__init__()
  183. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  184. self.proj = nn.Linear(dim, dim)
  185. self.split_h = split_h
  186. self.split_w = split_w
  187. mlp_hidden_dim = int(dim * mlp_ratio)
  188. self.norm1 = norm_layer(dim, epsilon=eps)
  189. self.h_num_heads = h_num_heads if h_num_heads is not None else num_heads // 2
  190. self.w_num_heads = w_num_heads if w_num_heads is not None else num_heads // 2
  191. self.head_dim = dim // num_heads
  192. self.mixer = HWAttention(
  193. head_dim=dim // num_heads,
  194. qk_scale=qk_scale,
  195. attn_drop=attn_drop,
  196. )
  197. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  198. self.norm2 = norm_layer(dim, epsilon=eps)
  199. self.mlp = Mlp(
  200. in_features=dim,
  201. hidden_features=mlp_hidden_dim,
  202. act_layer=act_layer,
  203. drop=drop,
  204. )
  205. def forward(self, x):
  206. B, C, H, W = x.shape
  207. x = x.flatten(2).transpose([0, 2, 1])
  208. qkv = self.qkv(x).reshape([B, H, W, 3 * C])
  209. x1 = qkv[:, :, :, : 3 * self.h_num_heads * self.head_dim] # b, h, w, 3ch
  210. x2 = qkv[:, :, :, 3 * self.h_num_heads * self.head_dim :] # b, h, w, 3cw
  211. x1 = self.mixer(img2windows(x1, self.split_h, W)) # b*splith, W, 3ch
  212. x2 = self.mixer(img2windows(x2, H, self.split_w)) # b*splitw, h, 3ch
  213. x1 = windows2img(x1, self.split_h, W, H, W)
  214. x2 = windows2img(x2, H, self.split_w, H, W)
  215. attened_x = paddle.concat([x1, x2], 2)
  216. attened_x = self.proj(attened_x)
  217. x = self.norm1(x + self.drop_path(attened_x))
  218. x = self.norm2(x + self.drop_path(self.mlp(x)))
  219. x = x.transpose([0, 2, 1]).reshape([-1, C, H, W])
  220. return x
  221. class SLAHead(nn.Layer):
  222. def __init__(
  223. self,
  224. in_channels,
  225. hidden_size,
  226. out_channels=30,
  227. max_text_length=500,
  228. loc_reg_num=4,
  229. fc_decay=0.0,
  230. use_attn=False,
  231. **kwargs,
  232. ):
  233. """
  234. @param in_channels: input shape
  235. @param hidden_size: hidden_size for RNN and Embedding
  236. @param out_channels: num_classes to rec
  237. @param max_text_length: max text pred
  238. """
  239. super().__init__()
  240. if isinstance(in_channels, int):
  241. self.is_next = True
  242. in_channels = 512
  243. else:
  244. self.is_next = False
  245. in_channels = in_channels[-1]
  246. self.hidden_size = hidden_size
  247. self.max_text_length = max_text_length
  248. self.emb = self._char_to_onehot
  249. self.num_embeddings = out_channels
  250. self.loc_reg_num = loc_reg_num
  251. self.eos = self.num_embeddings - 1
  252. # structure
  253. self.structure_attention_cell = AttentionGRUCell(
  254. in_channels, hidden_size, self.num_embeddings
  255. )
  256. weight_attr, bias_attr = get_para_bias_attr(l2_decay=fc_decay, k=hidden_size)
  257. weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
  258. l2_decay=fc_decay, k=hidden_size
  259. )
  260. weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
  261. l2_decay=fc_decay, k=hidden_size
  262. )
  263. self.structure_generator = nn.Sequential(
  264. nn.Linear(
  265. self.hidden_size,
  266. self.hidden_size,
  267. weight_attr=weight_attr1_2,
  268. bias_attr=bias_attr1_2,
  269. ),
  270. nn.Linear(
  271. hidden_size, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
  272. ),
  273. )
  274. dpr = np.linspace(0, 0.1, 2)
  275. self.use_attn = use_attn
  276. if use_attn:
  277. layer_list = [
  278. Block(
  279. in_channels,
  280. num_heads=2,
  281. mlp_ratio=4.0,
  282. qkv_bias=True,
  283. drop_path=dpr[i],
  284. )
  285. for i in range(2)
  286. ]
  287. self.cross_atten = nn.Sequential(*layer_list)
  288. # loc
  289. weight_attr1, bias_attr1 = get_para_bias_attr(
  290. l2_decay=fc_decay, k=self.hidden_size
  291. )
  292. weight_attr2, bias_attr2 = get_para_bias_attr(
  293. l2_decay=fc_decay, k=self.hidden_size
  294. )
  295. self.loc_generator = nn.Sequential(
  296. nn.Linear(
  297. self.hidden_size,
  298. self.hidden_size,
  299. weight_attr=weight_attr1,
  300. bias_attr=bias_attr1,
  301. ),
  302. nn.Linear(
  303. self.hidden_size,
  304. loc_reg_num,
  305. weight_attr=weight_attr2,
  306. bias_attr=bias_attr2,
  307. ),
  308. nn.Sigmoid(),
  309. )
  310. def forward(self, inputs, targets=None):
  311. if self.is_next == True:
  312. fea = inputs
  313. batch_size = fea.shape[0]
  314. else:
  315. fea = inputs[-1]
  316. batch_size = fea.shape[0]
  317. if self.use_attn:
  318. fea = fea + self.cross_atten(fea)
  319. # reshape
  320. fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
  321. fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
  322. hidden = paddle.zeros((batch_size, self.hidden_size))
  323. structure_preds = paddle.zeros(
  324. (batch_size, self.max_text_length + 1, self.num_embeddings)
  325. )
  326. loc_preds = paddle.zeros(
  327. (batch_size, self.max_text_length + 1, self.loc_reg_num)
  328. )
  329. structure_preds.stop_gradient = True
  330. loc_preds.stop_gradient = True
  331. if self.training and targets is not None:
  332. structure = targets[0]
  333. max_len = targets[-2].max().astype("int32")
  334. for i in range(max_len + 1):
  335. hidden, structure_step, loc_step = self._decode(
  336. structure[:, i], fea, hidden
  337. )
  338. structure_preds[:, i, :] = structure_step
  339. loc_preds[:, i, :] = loc_step
  340. structure_preds = structure_preds[:, : max_len + 1]
  341. loc_preds = loc_preds[:, : max_len + 1]
  342. else:
  343. structure_ids = paddle.zeros(
  344. (batch_size, self.max_text_length + 1), dtype="int32"
  345. )
  346. pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
  347. max_text_length = paddle.to_tensor(self.max_text_length)
  348. for i in range(max_text_length + 1):
  349. hidden, structure_step, loc_step = self._decode(pre_chars, fea, hidden)
  350. pre_chars = structure_step.argmax(axis=1, dtype="int32")
  351. structure_preds[:, i, :] = structure_step
  352. loc_preds[:, i, :] = loc_step
  353. structure_ids[:, i] = pre_chars
  354. if (structure_ids == self.eos).any(-1).all():
  355. break
  356. if not self.training:
  357. structure_preds = F.softmax(structure_preds[:, : i + 1])
  358. loc_preds = loc_preds[:, : i + 1]
  359. return {"structure_probs": structure_preds, "loc_preds": loc_preds}
  360. def _decode(self, pre_chars, features, hidden):
  361. """
  362. Predict table label and coordinates for each step
  363. @param pre_chars: Table label in previous step
  364. @param features:
  365. @param hidden: hidden status in previous step
  366. @return:
  367. """
  368. emb_feature = self.emb(pre_chars)
  369. # output shape is b * self.hidden_size
  370. (output, hidden), alpha = self.structure_attention_cell(
  371. hidden, features, emb_feature
  372. )
  373. # structure
  374. structure_step = self.structure_generator(output)
  375. # loc
  376. loc_step = self.loc_generator(output)
  377. return hidden, structure_step, loc_step
  378. def _char_to_onehot(self, input_char):
  379. input_ont_hot = F.one_hot(input_char, self.num_embeddings)
  380. return input_ont_hot