rec_aster_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import sys
  22. import paddle
  23. from paddle import nn
  24. from paddle.nn import functional as F
  25. class AsterHead(nn.Layer):
  26. def __init__(
  27. self,
  28. in_channels,
  29. out_channels,
  30. sDim,
  31. attDim,
  32. max_len_labels,
  33. time_step=25,
  34. beam_width=5,
  35. **kwargs,
  36. ):
  37. super(AsterHead, self).__init__()
  38. self.num_classes = out_channels
  39. self.in_planes = in_channels
  40. self.sDim = sDim
  41. self.attDim = attDim
  42. self.max_len_labels = max_len_labels
  43. self.decoder = AttentionRecognitionHead(
  44. in_channels, out_channels, sDim, attDim, max_len_labels
  45. )
  46. self.time_step = time_step
  47. self.embedder = Embedding(self.time_step, in_channels)
  48. self.beam_width = beam_width
  49. self.eos = self.num_classes - 3
  50. def forward(self, x, targets=None, embed=None):
  51. return_dict = {}
  52. embedding_vectors = self.embedder(x)
  53. if self.training:
  54. rec_targets, rec_lengths, _ = targets
  55. rec_pred = self.decoder([x, rec_targets, rec_lengths], embedding_vectors)
  56. return_dict["rec_pred"] = rec_pred
  57. return_dict["embedding_vectors"] = embedding_vectors
  58. else:
  59. rec_pred, rec_pred_scores = self.decoder.beam_search(
  60. x, self.beam_width, self.eos, embedding_vectors
  61. )
  62. return_dict["rec_pred"] = rec_pred
  63. return_dict["rec_pred_scores"] = rec_pred_scores
  64. return_dict["embedding_vectors"] = embedding_vectors
  65. return return_dict
  66. class Embedding(nn.Layer):
  67. def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
  68. super(Embedding, self).__init__()
  69. self.in_timestep = in_timestep
  70. self.in_planes = in_planes
  71. self.embed_dim = embed_dim
  72. self.mid_dim = mid_dim
  73. self.eEmbed = nn.Linear(
  74. in_timestep * in_planes, self.embed_dim
  75. ) # Embed encoder output to a word-embedding like
  76. def forward(self, x):
  77. x = paddle.reshape(x, [x.shape[0], -1])
  78. x = self.eEmbed(x)
  79. return x
  80. class AttentionRecognitionHead(nn.Layer):
  81. """
  82. input: [b x 16 x 64 x in_planes]
  83. output: probability sequence: [b x T x num_classes]
  84. """
  85. def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
  86. super(AttentionRecognitionHead, self).__init__()
  87. self.num_classes = (
  88. out_channels # this is the output classes. So it includes the <EOS>.
  89. )
  90. self.in_planes = in_channels
  91. self.sDim = sDim
  92. self.attDim = attDim
  93. self.max_len_labels = max_len_labels
  94. self.decoder = DecoderUnit(
  95. sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim
  96. )
  97. def forward(self, x, embed):
  98. x, targets, lengths = x
  99. batch_size = x.shape[0]
  100. # Decoder
  101. state = self.decoder.get_initial_state(embed)
  102. outputs = []
  103. for i in range(max(lengths)):
  104. if i == 0:
  105. y_prev = paddle.full(shape=[batch_size], fill_value=self.num_classes)
  106. else:
  107. y_prev = targets[:, i - 1]
  108. output, state = self.decoder(x, state, y_prev)
  109. outputs.append(output)
  110. outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
  111. return outputs
  112. # inference stage.
  113. def sample(self, x):
  114. x, _, _ = x
  115. batch_size = x.size(0)
  116. # Decoder
  117. state = paddle.zeros([1, batch_size, self.sDim])
  118. predicted_ids, predicted_scores, predicted = [], [], None
  119. for i in range(self.max_len_labels):
  120. if i == 0:
  121. y_prev = paddle.full(shape=[batch_size], fill_value=self.num_classes)
  122. else:
  123. y_prev = predicted
  124. output, state = self.decoder(x, state, y_prev)
  125. output = F.softmax(output, axis=1)
  126. score, predicted = output.max(1)
  127. predicted_ids.append(predicted.unsqueeze(1))
  128. predicted_scores.append(score.unsqueeze(1))
  129. predicted_ids = paddle.concat([predicted_ids, 1])
  130. predicted_scores = paddle.concat([predicted_scores, 1])
  131. # return predicted_ids.squeeze(), predicted_scores.squeeze()
  132. return predicted_ids, predicted_scores
  133. def beam_search(self, x, beam_width, eos, embed):
  134. def _inflate(tensor, times, dim):
  135. repeat_dims = [1] * tensor.dim()
  136. repeat_dims[dim] = times
  137. output = paddle.tile(tensor, repeat_dims)
  138. return output
  139. # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
  140. batch_size, l, d = x.shape
  141. x = paddle.tile(
  142. paddle.transpose(x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1]
  143. )
  144. inflated_encoder_feats = paddle.reshape(
  145. paddle.transpose(x, perm=[1, 0, 2, 3]), [-1, l, d]
  146. )
  147. # Initialize the decoder
  148. state = self.decoder.get_initial_state(embed, tile_times=beam_width)
  149. pos_index = paddle.reshape(
  150. paddle.arange(batch_size) * beam_width, shape=[-1, 1]
  151. )
  152. # Initialize the scores
  153. sequence_scores = paddle.full(
  154. shape=[batch_size * beam_width, 1], fill_value=-float("Inf")
  155. )
  156. index = [i * beam_width for i in range(0, batch_size)]
  157. sequence_scores[index] = 0.0
  158. # Initialize the input vector
  159. y_prev = paddle.full(
  160. shape=[batch_size * beam_width], fill_value=self.num_classes
  161. )
  162. # Store decisions for backtracking
  163. stored_scores = list()
  164. stored_predecessors = list()
  165. stored_emitted_symbols = list()
  166. for i in range(self.max_len_labels):
  167. output, state = self.decoder(inflated_encoder_feats, state, y_prev)
  168. state = paddle.unsqueeze(state, axis=0)
  169. log_softmax_output = paddle.nn.functional.log_softmax(output, axis=1)
  170. sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
  171. sequence_scores += log_softmax_output
  172. scores, candidates = paddle.topk(
  173. paddle.reshape(sequence_scores, [batch_size, -1]), beam_width, axis=1
  174. )
  175. # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
  176. y_prev = paddle.reshape(
  177. candidates % self.num_classes, shape=[batch_size * beam_width]
  178. )
  179. sequence_scores = paddle.reshape(scores, shape=[batch_size * beam_width, 1])
  180. # Update fields for next timestep
  181. pos_index = paddle.expand_as(pos_index, candidates)
  182. predecessors = paddle.cast(
  183. candidates / self.num_classes + pos_index, dtype="int64"
  184. )
  185. predecessors = paddle.reshape(
  186. predecessors, shape=[batch_size * beam_width, 1]
  187. )
  188. state = paddle.index_select(state, index=predecessors.squeeze(), axis=1)
  189. # Update sequence scores and erase scores for <eos> symbol so that they aren't expanded
  190. stored_scores.append(sequence_scores.clone())
  191. y_prev = paddle.reshape(y_prev, shape=[-1, 1])
  192. eos_prev = paddle.full_like(y_prev, fill_value=eos)
  193. mask = eos_prev == y_prev
  194. mask = paddle.nonzero(mask)
  195. if mask.dim() > 0:
  196. sequence_scores = sequence_scores.numpy()
  197. mask = mask.numpy()
  198. sequence_scores[mask] = -float("inf")
  199. sequence_scores = paddle.to_tensor(sequence_scores)
  200. # Cache results for backtracking
  201. stored_predecessors.append(predecessors)
  202. y_prev = paddle.squeeze(y_prev)
  203. stored_emitted_symbols.append(y_prev)
  204. # Do backtracking to return the optimal values
  205. # ====== backtrak ======#
  206. # Initialize return variables given different types
  207. p = list()
  208. l = [
  209. [self.max_len_labels] * beam_width for _ in range(batch_size)
  210. ] # Placeholder for lengths of top-k sequences
  211. # the last step output of the beams are not sorted
  212. # thus they are sorted here
  213. sorted_score, sorted_idx = paddle.topk(
  214. paddle.reshape(stored_scores[-1], shape=[batch_size, beam_width]),
  215. beam_width,
  216. )
  217. # initialize the sequence scores with the sorted last step beam scores
  218. s = sorted_score.clone()
  219. batch_eos_found = [0] * batch_size # the number of EOS found
  220. # in the backward loop below for each batch
  221. t = self.max_len_labels - 1
  222. # initialize the back pointer with the sorted order of the last step beams.
  223. # add pos_index for indexing variable with b*k as the first dimension.
  224. t_predecessors = paddle.reshape(
  225. sorted_idx + pos_index.expand_as(sorted_idx),
  226. shape=[batch_size * beam_width],
  227. )
  228. while t >= 0:
  229. # Re-order the variables with the back pointer
  230. current_symbol = paddle.index_select(
  231. stored_emitted_symbols[t], index=t_predecessors, axis=0
  232. )
  233. t_predecessors = paddle.index_select(
  234. stored_predecessors[t].squeeze(), index=t_predecessors, axis=0
  235. )
  236. eos_indices = stored_emitted_symbols[t] == eos
  237. eos_indices = paddle.nonzero(eos_indices)
  238. if eos_indices.dim() > 0:
  239. for i in range(eos_indices.shape[0] - 1, -1, -1):
  240. # Indices of the EOS symbol for both variables
  241. # with b*k as the first dimension, and b, k for
  242. # the first two dimensions
  243. idx = eos_indices[i]
  244. b_idx = int(idx[0] / beam_width)
  245. # The indices of the replacing position
  246. # according to the replacement strategy noted above
  247. res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1
  248. batch_eos_found[b_idx] += 1
  249. res_idx = b_idx * beam_width + res_k_idx
  250. # Replace the old information in return variables
  251. # with the new ended sequence information
  252. t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
  253. current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
  254. s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
  255. l[b_idx][res_k_idx] = t + 1
  256. # record the back tracked results
  257. p.append(current_symbol)
  258. t -= 1
  259. # Sort and re-order again as the added ended sequences may change
  260. # the order (very unlikely)
  261. s, re_sorted_idx = s.topk(beam_width)
  262. for b_idx in range(batch_size):
  263. l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]
  264. re_sorted_idx = paddle.reshape(
  265. re_sorted_idx + pos_index.expand_as(re_sorted_idx),
  266. [batch_size * beam_width],
  267. )
  268. # Reverse the sequences and re-order at the same time
  269. # It is reversed because the backtracking happens in reverse time order
  270. p = [
  271. paddle.reshape(
  272. paddle.index_select(step, re_sorted_idx, 0),
  273. shape=[batch_size, beam_width, -1],
  274. )
  275. for step in reversed(p)
  276. ]
  277. p = paddle.concat(p, -1)[:, 0, :]
  278. return p, paddle.ones_like(p)
  279. class AttentionUnit(nn.Layer):
  280. def __init__(self, sDim, xDim, attDim):
  281. super(AttentionUnit, self).__init__()
  282. self.sDim = sDim
  283. self.xDim = xDim
  284. self.attDim = attDim
  285. self.sEmbed = nn.Linear(sDim, attDim)
  286. self.xEmbed = nn.Linear(xDim, attDim)
  287. self.wEmbed = nn.Linear(attDim, 1)
  288. def forward(self, x, sPrev):
  289. batch_size, T, _ = x.shape # [b x T x xDim]
  290. x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
  291. xProj = self.xEmbed(x) # [(b x T) x attDim]
  292. xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
  293. sPrev = sPrev.squeeze(0)
  294. sProj = self.sEmbed(sPrev) # [b x attDim]
  295. sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
  296. sProj = paddle.expand(sProj, [batch_size, T, self.attDim]) # [b x T x attDim]
  297. sumTanh = paddle.tanh(sProj + xProj)
  298. sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
  299. vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
  300. vProj = paddle.reshape(vProj, [batch_size, T])
  301. alpha = F.softmax(
  302. vProj, axis=1
  303. ) # attention weights for each sample in the minibatch
  304. return alpha
  305. class DecoderUnit(nn.Layer):
  306. def __init__(self, sDim, xDim, yDim, attDim):
  307. super(DecoderUnit, self).__init__()
  308. self.sDim = sDim
  309. self.xDim = xDim
  310. self.yDim = yDim
  311. self.attDim = attDim
  312. self.emdDim = attDim
  313. self.attention_unit = AttentionUnit(sDim, xDim, attDim)
  314. self.tgt_embedding = nn.Embedding(
  315. yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(std=0.01)
  316. ) # the last is used for <BOS>
  317. self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
  318. self.fc = nn.Linear(
  319. sDim,
  320. yDim,
  321. weight_attr=nn.initializer.Normal(std=0.01),
  322. bias_attr=nn.initializer.Constant(value=0),
  323. )
  324. self.embed_fc = nn.Linear(300, self.sDim)
  325. def get_initial_state(self, embed, tile_times=1):
  326. assert embed.shape[1] == 300
  327. state = self.embed_fc(embed) # N * sDim
  328. if tile_times != 1:
  329. state = state.unsqueeze(1)
  330. trans_state = paddle.transpose(state, perm=[1, 0, 2])
  331. state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
  332. trans_state = paddle.transpose(state, perm=[1, 0, 2])
  333. state = paddle.reshape(trans_state, shape=[-1, self.sDim])
  334. state = state.unsqueeze(0) # 1 * N * sDim
  335. return state
  336. def forward(self, x, sPrev, yPrev):
  337. # x: feature sequence from the image decoder.
  338. batch_size, T, _ = x.shape
  339. alpha = self.attention_unit(x, sPrev)
  340. context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
  341. yPrev = paddle.cast(yPrev, dtype="int64")
  342. yProj = self.tgt_embedding(yPrev)
  343. concat_context = paddle.concat([yProj, context], 1)
  344. concat_context = paddle.squeeze(concat_context, 1)
  345. sPrev = paddle.squeeze(sPrev, 0)
  346. output, state = self.gru(concat_context, sPrev)
  347. output = paddle.squeeze(output, axis=1)
  348. output = self.fc(output)
  349. return output, state