rec_visionlan_head.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is adapted from:
  16. https://github.com/wangyuxin87/VisionLAN
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. from paddle import ParamAttr
  23. import paddle.nn as nn
  24. import paddle.nn.functional as F
  25. from paddle.nn.initializer import Normal, XavierNormal
  26. import numpy as np
  27. class PositionalEncoding(nn.Layer):
  28. def __init__(self, d_hid, n_position=200):
  29. super(PositionalEncoding, self).__init__()
  30. self.register_buffer(
  31. "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
  32. )
  33. def _get_sinusoid_encoding_table(self, n_position, d_hid):
  34. """Sinusoid position encoding table"""
  35. def get_position_angle_vec(position):
  36. return [
  37. position / np.power(10000, 2 * (hid_j // 2) / d_hid)
  38. for hid_j in range(d_hid)
  39. ]
  40. sinusoid_table = np.array(
  41. [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
  42. )
  43. sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
  44. sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
  45. sinusoid_table = paddle.to_tensor(sinusoid_table, dtype="float32")
  46. sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
  47. return sinusoid_table
  48. def forward(self, x):
  49. return x + self.pos_table[:, : x.shape[1]].clone().detach()
  50. class ScaledDotProductAttention(nn.Layer):
  51. "Scaled Dot-Product Attention"
  52. def __init__(self, temperature, attn_dropout=0.1):
  53. super(ScaledDotProductAttention, self).__init__()
  54. self.temperature = temperature
  55. self.dropout = nn.Dropout(attn_dropout)
  56. self.softmax = nn.Softmax(axis=2)
  57. def forward(self, q, k, v, mask=None):
  58. k = paddle.transpose(k, perm=[0, 2, 1])
  59. attn = paddle.bmm(q, k)
  60. attn = attn / self.temperature
  61. if mask is not None:
  62. attn = attn.masked_fill(mask, -1e9)
  63. if mask.dim() == 3:
  64. mask = paddle.unsqueeze(mask, axis=1)
  65. elif mask.dim() == 2:
  66. mask = paddle.unsqueeze(mask, axis=1)
  67. mask = paddle.unsqueeze(mask, axis=1)
  68. repeat_times = [
  69. attn.shape[1] // mask.shape[1],
  70. attn.shape[2] // mask.shape[2],
  71. ]
  72. mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
  73. attn[mask == 0] = -1e9
  74. attn = self.softmax(attn)
  75. attn = self.dropout(attn)
  76. output = paddle.bmm(attn, v)
  77. return output
  78. class MultiHeadAttention(nn.Layer):
  79. "Multi-Head Attention module"
  80. def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
  81. super(MultiHeadAttention, self).__init__()
  82. self.n_head = n_head
  83. self.d_k = d_k
  84. self.d_v = d_v
  85. self.w_qs = nn.Linear(
  86. d_model,
  87. n_head * d_k,
  88. weight_attr=ParamAttr(
  89. initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
  90. ),
  91. )
  92. self.w_ks = nn.Linear(
  93. d_model,
  94. n_head * d_k,
  95. weight_attr=ParamAttr(
  96. initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
  97. ),
  98. )
  99. self.w_vs = nn.Linear(
  100. d_model,
  101. n_head * d_v,
  102. weight_attr=ParamAttr(
  103. initializer=Normal(mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
  104. ),
  105. )
  106. self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
  107. self.layer_norm = nn.LayerNorm(d_model)
  108. self.fc = nn.Linear(
  109. n_head * d_v, d_model, weight_attr=ParamAttr(initializer=XavierNormal())
  110. )
  111. self.dropout = nn.Dropout(dropout)
  112. def forward(self, q, k, v, mask=None):
  113. d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
  114. sz_b, len_q, _ = q.shape
  115. sz_b, len_k, _ = k.shape
  116. sz_b, len_v, _ = v.shape
  117. residual = q
  118. q = self.w_qs(q)
  119. q = paddle.reshape(q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
  120. k = self.w_ks(k)
  121. k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
  122. v = self.w_vs(v)
  123. v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
  124. q = paddle.transpose(q, perm=[2, 0, 1, 3])
  125. q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
  126. k = paddle.transpose(k, perm=[2, 0, 1, 3])
  127. k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
  128. v = paddle.transpose(v, perm=[2, 0, 1, 3])
  129. v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
  130. mask = (
  131. paddle.tile(mask, [n_head, 1, 1]) if mask is not None else None
  132. ) # (n*b) x .. x ..
  133. output = self.attention(q, k, v, mask=mask)
  134. output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
  135. output = paddle.transpose(output, perm=[1, 2, 0, 3])
  136. output = paddle.reshape(
  137. output, shape=[-1, len_q, n_head * d_v]
  138. ) # b x lq x (n*dv)
  139. output = self.dropout(self.fc(output))
  140. output = self.layer_norm(output + residual)
  141. return output
  142. class PositionwiseFeedForward(nn.Layer):
  143. def __init__(self, d_in, d_hid, dropout=0.1):
  144. super(PositionwiseFeedForward, self).__init__()
  145. self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
  146. self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
  147. self.layer_norm = nn.LayerNorm(d_in)
  148. self.dropout = nn.Dropout(dropout)
  149. def forward(self, x):
  150. residual = x
  151. x = paddle.transpose(x, perm=[0, 2, 1])
  152. x = self.w_2(F.relu(self.w_1(x)))
  153. x = paddle.transpose(x, perm=[0, 2, 1])
  154. x = self.dropout(x)
  155. x = self.layer_norm(x + residual)
  156. return x
  157. class EncoderLayer(nn.Layer):
  158. """Compose with two layers"""
  159. def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
  160. super(EncoderLayer, self).__init__()
  161. self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  162. self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  163. def forward(self, enc_input, slf_attn_mask=None):
  164. enc_output = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
  165. enc_output = self.pos_ffn(enc_output)
  166. return enc_output
  167. class Transformer_Encoder(nn.Layer):
  168. def __init__(
  169. self,
  170. n_layers=2,
  171. n_head=8,
  172. d_word_vec=512,
  173. d_k=64,
  174. d_v=64,
  175. d_model=512,
  176. d_inner=2048,
  177. dropout=0.1,
  178. n_position=256,
  179. ):
  180. super(Transformer_Encoder, self).__init__()
  181. self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
  182. self.dropout = nn.Dropout(p=dropout)
  183. self.layer_stack = nn.LayerList(
  184. [
  185. EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
  186. for _ in range(n_layers)
  187. ]
  188. )
  189. self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
  190. def forward(self, enc_output, src_mask, return_attns=False):
  191. enc_output = self.dropout(self.position_enc(enc_output)) # position embedding
  192. for enc_layer in self.layer_stack:
  193. enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
  194. enc_output = self.layer_norm(enc_output)
  195. return enc_output
  196. class PP_layer(nn.Layer):
  197. def __init__(self, n_dim=512, N_max_character=25, n_position=256):
  198. super(PP_layer, self).__init__()
  199. self.character_len = N_max_character
  200. self.f0_embedding = nn.Embedding(N_max_character, n_dim)
  201. self.w0 = nn.Linear(N_max_character, n_position)
  202. self.wv = nn.Linear(n_dim, n_dim)
  203. self.we = nn.Linear(n_dim, N_max_character)
  204. self.active = nn.Tanh()
  205. self.softmax = nn.Softmax(axis=2)
  206. def forward(self, enc_output):
  207. # enc_output: b,256,512
  208. reading_order = paddle.arange(self.character_len, dtype="int64")
  209. reading_order = reading_order.unsqueeze(0).expand(
  210. [enc_output.shape[0], self.character_len]
  211. ) # (S,) -> (B, S)
  212. reading_order = self.f0_embedding(reading_order) # b,25,512
  213. # calculate attention
  214. reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
  215. t = self.w0(reading_order) # b,512,256
  216. t = self.active(
  217. paddle.transpose(t, perm=[0, 2, 1]) + self.wv(enc_output)
  218. ) # b,256,512
  219. t = self.we(t) # b,256,25
  220. t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
  221. g_output = paddle.bmm(t, enc_output) # b,25,512
  222. return g_output
  223. class Prediction(nn.Layer):
  224. def __init__(self, n_dim=512, n_position=256, N_max_character=25, n_class=37):
  225. super(Prediction, self).__init__()
  226. self.pp = PP_layer(
  227. n_dim=n_dim, N_max_character=N_max_character, n_position=n_position
  228. )
  229. self.pp_share = PP_layer(
  230. n_dim=n_dim, N_max_character=N_max_character, n_position=n_position
  231. )
  232. self.w_vrm = nn.Linear(n_dim, n_class) # output layer
  233. self.w_share = nn.Linear(n_dim, n_class) # output layer
  234. self.nclass = n_class
  235. def forward(self, cnn_feature, f_res, f_sub, train_mode=False, use_mlm=True):
  236. if train_mode:
  237. if not use_mlm:
  238. g_output = self.pp(cnn_feature) # b,25,512
  239. g_output = self.w_vrm(g_output)
  240. f_res = 0
  241. f_sub = 0
  242. return g_output, f_res, f_sub
  243. g_output = self.pp(cnn_feature) # b,25,512
  244. f_res = self.pp_share(f_res)
  245. f_sub = self.pp_share(f_sub)
  246. g_output = self.w_vrm(g_output)
  247. f_res = self.w_share(f_res)
  248. f_sub = self.w_share(f_sub)
  249. return g_output, f_res, f_sub
  250. else:
  251. g_output = self.pp(cnn_feature) # b,25,512
  252. g_output = self.w_vrm(g_output)
  253. return g_output
  254. class MLM(nn.Layer):
  255. "Architecture of MLM"
  256. def __init__(self, n_dim=512, n_position=256, max_text_length=25):
  257. super(MLM, self).__init__()
  258. self.MLM_SequenceModeling_mask = Transformer_Encoder(
  259. n_layers=2, n_position=n_position
  260. )
  261. self.MLM_SequenceModeling_WCL = Transformer_Encoder(
  262. n_layers=1, n_position=n_position
  263. )
  264. self.pos_embedding = nn.Embedding(max_text_length, n_dim)
  265. self.w0_linear = nn.Linear(1, n_position)
  266. self.wv = nn.Linear(n_dim, n_dim)
  267. self.active = nn.Tanh()
  268. self.we = nn.Linear(n_dim, 1)
  269. self.sigmoid = nn.Sigmoid()
  270. def forward(self, x, label_pos):
  271. # transformer unit for generating mask_c
  272. feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
  273. # position embedding layer
  274. label_pos = paddle.to_tensor(label_pos, dtype="int64")
  275. pos_emb = self.pos_embedding(label_pos)
  276. pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
  277. pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
  278. # fusion position embedding with features V & generate mask_c
  279. att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
  280. att_map_sub = self.we(att_map_sub) # b,256,1
  281. att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
  282. att_map_sub = self.sigmoid(att_map_sub) # b,1,256
  283. # WCL
  284. ## generate inputs for WCL
  285. att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
  286. f_res = x * (1 - att_map_sub) # second path with remaining string
  287. f_sub = x * att_map_sub # first path with occluded character
  288. ## transformer units in WCL
  289. f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
  290. f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
  291. return f_res, f_sub, att_map_sub
  292. def trans_1d_2d(x):
  293. b, w_h, c = x.shape # b, 256, 512
  294. x = paddle.transpose(x, perm=[0, 2, 1])
  295. x = paddle.reshape(x, [-1, c, 32, 8])
  296. x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
  297. return x
  298. class MLM_VRM(nn.Layer):
  299. """
  300. MLM+VRM, MLM is only used in training.
  301. ratio controls the occluded number in a batch.
  302. The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
  303. x: input image
  304. label_pos: character index
  305. training_step: LF or LA process
  306. output
  307. text_pre: prediction of VRM
  308. test_rem: prediction of remaining string in MLM
  309. text_mas: prediction of occluded character in MLM
  310. mask_c_show: visualization of Mask_c
  311. """
  312. def __init__(
  313. self, n_layers=3, n_position=256, n_dim=512, max_text_length=25, nclass=37
  314. ):
  315. super(MLM_VRM, self).__init__()
  316. self.MLM = MLM(
  317. n_dim=n_dim, n_position=n_position, max_text_length=max_text_length
  318. )
  319. self.SequenceModeling = Transformer_Encoder(
  320. n_layers=n_layers, n_position=n_position
  321. )
  322. self.Prediction = Prediction(
  323. n_dim=n_dim,
  324. n_position=n_position,
  325. N_max_character=max_text_length
  326. + 1, # N_max_character = 1 eos + 25 characters
  327. n_class=nclass,
  328. )
  329. self.nclass = nclass
  330. self.max_text_length = max_text_length
  331. def forward(self, x, label_pos, training_step, train_mode=False):
  332. b, c, h, w = x.shape
  333. nT = self.max_text_length
  334. x = paddle.transpose(x, perm=[0, 1, 3, 2])
  335. x = paddle.reshape(x, [-1, c, h * w])
  336. x = paddle.transpose(x, perm=[0, 2, 1])
  337. if train_mode:
  338. if training_step == "LF_1":
  339. f_res = 0
  340. f_sub = 0
  341. x = self.SequenceModeling(x, src_mask=None)
  342. text_pre, test_rem, text_mas = self.Prediction(
  343. x, f_res, f_sub, train_mode=True, use_mlm=False
  344. )
  345. return text_pre, text_pre, text_pre, text_pre
  346. elif training_step == "LF_2":
  347. # MLM
  348. f_res, f_sub, mask_c = self.MLM(x, label_pos)
  349. x = self.SequenceModeling(x, src_mask=None)
  350. text_pre, test_rem, text_mas = self.Prediction(
  351. x, f_res, f_sub, train_mode=True
  352. )
  353. mask_c_show = trans_1d_2d(mask_c)
  354. return text_pre, test_rem, text_mas, mask_c_show
  355. elif training_step == "LA":
  356. # MLM
  357. f_res, f_sub, mask_c = self.MLM(x, label_pos)
  358. ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
  359. ## ratio controls the occluded number in a batch
  360. character_mask = paddle.zeros_like(mask_c)
  361. ratio = b // 2
  362. if ratio >= 1:
  363. with paddle.no_grad():
  364. character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
  365. else:
  366. character_mask = mask_c
  367. x = x * (1 - character_mask)
  368. # VRM
  369. ## transformer unit for VRM
  370. x = self.SequenceModeling(x, src_mask=None)
  371. ## prediction layer for MLM and VSR
  372. text_pre, test_rem, text_mas = self.Prediction(
  373. x, f_res, f_sub, train_mode=True
  374. )
  375. mask_c_show = trans_1d_2d(mask_c)
  376. return text_pre, test_rem, text_mas, mask_c_show
  377. else:
  378. raise NotImplementedError
  379. else: # VRM is only used in the testing stage
  380. f_res = 0
  381. f_sub = 0
  382. contextual_feature = self.SequenceModeling(x, src_mask=None)
  383. text_pre = self.Prediction(
  384. contextual_feature, f_res, f_sub, train_mode=False, use_mlm=False
  385. )
  386. text_pre = paddle.transpose(text_pre, perm=[1, 0, 2]) # (26, b, 37))
  387. return text_pre, x
  388. class VLHead(nn.Layer):
  389. """
  390. Architecture of VisionLAN
  391. """
  392. def __init__(
  393. self,
  394. in_channels,
  395. out_channels=36,
  396. n_layers=3,
  397. n_position=256,
  398. n_dim=512,
  399. max_text_length=25,
  400. training_step="LA",
  401. ):
  402. super(VLHead, self).__init__()
  403. self.MLM_VRM = MLM_VRM(
  404. n_layers=n_layers,
  405. n_position=n_position,
  406. n_dim=n_dim,
  407. max_text_length=max_text_length,
  408. nclass=out_channels + 1,
  409. )
  410. self.training_step = training_step
  411. def forward(self, feat, targets=None):
  412. if self.training:
  413. label_pos = targets[-2]
  414. text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
  415. feat, label_pos, self.training_step, train_mode=True
  416. )
  417. return text_pre, test_rem, text_mas, mask_map
  418. else:
  419. text_pre, x = self.MLM_VRM(
  420. feat, targets, self.training_step, train_mode=False
  421. )
  422. return text_pre, x