rec_can_head.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """
  15. This code is refer from:
  16. https://github.com/LBH1024/CAN/models/can.py
  17. https://github.com/LBH1024/CAN/models/counting.py
  18. https://github.com/LBH1024/CAN/models/decoder.py
  19. https://github.com/LBH1024/CAN/models/attention.py
  20. """
  21. from __future__ import absolute_import
  22. from __future__ import division
  23. from __future__ import print_function
  24. import paddle.nn as nn
  25. import paddle
  26. import math
  27. """
  28. Counting Module
  29. """
  30. class ChannelAtt(nn.Layer):
  31. def __init__(self, channel, reduction):
  32. super(ChannelAtt, self).__init__()
  33. self.avg_pool = nn.AdaptiveAvgPool2D(1)
  34. self.fc = nn.Sequential(
  35. nn.Linear(channel, channel // reduction),
  36. nn.ReLU(),
  37. nn.Linear(channel // reduction, channel),
  38. nn.Sigmoid(),
  39. )
  40. def forward(self, x):
  41. b, c, _, _ = x.shape
  42. y = paddle.reshape(self.avg_pool(x), [b, c])
  43. y = paddle.reshape(self.fc(y), [b, c, 1, 1])
  44. return x * y
  45. class CountingDecoder(nn.Layer):
  46. def __init__(self, in_channel, out_channel, kernel_size):
  47. super(CountingDecoder, self).__init__()
  48. self.in_channel = in_channel
  49. self.out_channel = out_channel
  50. self.trans_layer = nn.Sequential(
  51. nn.Conv2D(
  52. self.in_channel,
  53. 512,
  54. kernel_size=kernel_size,
  55. padding=kernel_size // 2,
  56. bias_attr=False,
  57. ),
  58. nn.BatchNorm2D(512),
  59. )
  60. self.channel_att = ChannelAtt(512, 16)
  61. self.pred_layer = nn.Sequential(
  62. nn.Conv2D(512, self.out_channel, kernel_size=1, bias_attr=False),
  63. nn.Sigmoid(),
  64. )
  65. def forward(self, x, mask):
  66. b, _, h, w = x.shape
  67. x = self.trans_layer(x)
  68. x = self.channel_att(x)
  69. x = self.pred_layer(x)
  70. if mask is not None:
  71. x = x * mask
  72. x = paddle.reshape(x, [b, self.out_channel, -1])
  73. x1 = paddle.sum(x, axis=-1)
  74. return x1, paddle.reshape(x, [b, self.out_channel, h, w])
  75. """
  76. Attention Decoder
  77. """
  78. class PositionEmbeddingSine(nn.Layer):
  79. def __init__(
  80. self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
  81. ):
  82. super().__init__()
  83. self.num_pos_feats = num_pos_feats
  84. self.temperature = temperature
  85. self.normalize = normalize
  86. if scale is not None and normalize is False:
  87. raise ValueError("normalize should be True if scale is passed")
  88. if scale is None:
  89. scale = 2 * math.pi
  90. self.scale = scale
  91. def forward(self, x, mask):
  92. y_embed = paddle.cumsum(mask, 1, dtype="float32")
  93. x_embed = paddle.cumsum(mask, 2, dtype="float32")
  94. if self.normalize:
  95. eps = 1e-6
  96. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  97. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  98. dim_t = paddle.arange(self.num_pos_feats, dtype="float32")
  99. dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
  100. dim_t = self.temperature ** (
  101. 2 * (dim_t / dim_d).astype("int64") / self.num_pos_feats
  102. )
  103. pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
  104. pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
  105. pos_x = paddle.flatten(
  106. paddle.stack(
  107. [paddle.sin(pos_x[:, :, :, 0::2]), paddle.cos(pos_x[:, :, :, 1::2])],
  108. axis=4,
  109. ),
  110. 3,
  111. )
  112. pos_y = paddle.flatten(
  113. paddle.stack(
  114. [paddle.sin(pos_y[:, :, :, 0::2]), paddle.cos(pos_y[:, :, :, 1::2])],
  115. axis=4,
  116. ),
  117. 3,
  118. )
  119. pos = paddle.transpose(paddle.concat([pos_y, pos_x], axis=3), [0, 3, 1, 2])
  120. return pos
  121. class AttDecoder(nn.Layer):
  122. def __init__(
  123. self,
  124. ratio,
  125. is_train,
  126. input_size,
  127. hidden_size,
  128. encoder_out_channel,
  129. dropout,
  130. dropout_ratio,
  131. word_num,
  132. counting_decoder_out_channel,
  133. attention,
  134. ):
  135. super(AttDecoder, self).__init__()
  136. self.input_size = input_size
  137. self.hidden_size = hidden_size
  138. self.out_channel = encoder_out_channel
  139. self.attention_dim = attention["attention_dim"]
  140. self.dropout_prob = dropout
  141. self.ratio = ratio
  142. self.word_num = word_num
  143. self.counting_num = counting_decoder_out_channel
  144. self.is_train = is_train
  145. self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
  146. self.embedding = nn.Embedding(self.word_num, self.input_size)
  147. self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
  148. self.word_attention = Attention(hidden_size, attention["attention_dim"])
  149. self.encoder_feature_conv = nn.Conv2D(
  150. self.out_channel,
  151. self.attention_dim,
  152. kernel_size=attention["word_conv_kernel"],
  153. padding=attention["word_conv_kernel"] // 2,
  154. )
  155. self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
  156. self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
  157. self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
  158. self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size)
  159. self.word_convert = nn.Linear(self.hidden_size, self.word_num)
  160. if dropout:
  161. self.dropout = nn.Dropout(dropout_ratio)
  162. def forward(self, cnn_features, labels, counting_preds, images_mask):
  163. if self.is_train:
  164. _, num_steps = labels.shape
  165. else:
  166. num_steps = 36
  167. batch_size, _, height, width = cnn_features.shape
  168. images_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
  169. word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
  170. word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
  171. hidden = self.init_hidden(cnn_features, images_mask)
  172. counting_context_weighted = self.counting_context_weight(counting_preds)
  173. cnn_features_trans = self.encoder_feature_conv(cnn_features)
  174. position_embedding = PositionEmbeddingSine(256, normalize=True)
  175. pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
  176. cnn_features_trans = cnn_features_trans + pos
  177. word = paddle.ones([batch_size, 1], dtype="int64") # init word as sos
  178. word = word.squeeze(axis=1)
  179. for i in range(num_steps):
  180. word_embedding = self.embedding(word)
  181. _, hidden = self.word_input_gru(word_embedding, hidden)
  182. word_context_vec, _, word_alpha_sum = self.word_attention(
  183. cnn_features, cnn_features_trans, hidden, word_alpha_sum, images_mask
  184. )
  185. current_state = self.word_state_weight(hidden)
  186. word_weighted_embedding = self.word_embedding_weight(word_embedding)
  187. word_context_weighted = self.word_context_weight(word_context_vec)
  188. if self.dropout_prob:
  189. word_out_state = self.dropout(
  190. current_state
  191. + word_weighted_embedding
  192. + word_context_weighted
  193. + counting_context_weighted
  194. )
  195. else:
  196. word_out_state = (
  197. current_state
  198. + word_weighted_embedding
  199. + word_context_weighted
  200. + counting_context_weighted
  201. )
  202. word_prob = self.word_convert(word_out_state)
  203. word_probs[:, i] = word_prob
  204. if self.is_train:
  205. word = labels[:, i]
  206. else:
  207. word = word_prob.argmax(1)
  208. word = paddle.multiply(
  209. word, labels[:, i]
  210. ) # labels are oneslike tensor in infer/predict mode
  211. return word_probs
  212. def init_hidden(self, features, feature_mask):
  213. average = paddle.sum(
  214. paddle.sum(features * feature_mask, axis=-1), axis=-1
  215. ) / paddle.sum((paddle.sum(feature_mask, axis=-1)), axis=-1)
  216. average = self.init_weight(average)
  217. return paddle.tanh(average)
  218. """
  219. Attention Module
  220. """
  221. class Attention(nn.Layer):
  222. def __init__(self, hidden_size, attention_dim):
  223. super(Attention, self).__init__()
  224. self.hidden = hidden_size
  225. self.attention_dim = attention_dim
  226. self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
  227. self.attention_conv = nn.Conv2D(
  228. 1, 512, kernel_size=11, padding=5, bias_attr=False
  229. )
  230. self.attention_weight = nn.Linear(512, self.attention_dim, bias_attr=False)
  231. self.alpha_convert = nn.Linear(self.attention_dim, 1)
  232. def forward(
  233. self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask=None
  234. ):
  235. query = self.hidden_weight(hidden)
  236. alpha_sum_trans = self.attention_conv(alpha_sum)
  237. coverage_alpha = self.attention_weight(
  238. paddle.transpose(alpha_sum_trans, [0, 2, 3, 1])
  239. )
  240. alpha_score = paddle.tanh(
  241. paddle.unsqueeze(query, [1, 2])
  242. + coverage_alpha
  243. + paddle.transpose(cnn_features_trans, [0, 2, 3, 1])
  244. )
  245. energy = self.alpha_convert(alpha_score)
  246. energy = energy - energy.max()
  247. energy_exp = paddle.exp(paddle.squeeze(energy, -1))
  248. if image_mask is not None:
  249. energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
  250. alpha = energy_exp / (
  251. paddle.unsqueeze(paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10
  252. )
  253. alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
  254. context_vector = paddle.sum(
  255. paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1
  256. )
  257. return context_vector, alpha, alpha_sum
  258. class CANHead(nn.Layer):
  259. def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
  260. super(CANHead, self).__init__()
  261. self.in_channel = in_channel
  262. self.out_channel = out_channel
  263. self.counting_decoder1 = CountingDecoder(
  264. self.in_channel, self.out_channel, 3
  265. ) # mscm
  266. self.counting_decoder2 = CountingDecoder(self.in_channel, self.out_channel, 5)
  267. self.decoder = AttDecoder(ratio, **attdecoder)
  268. self.ratio = ratio
  269. def forward(self, inputs, targets=None):
  270. cnn_features, images_mask, labels = inputs
  271. counting_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
  272. counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
  273. counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
  274. counting_preds = (counting_preds1 + counting_preds2) / 2
  275. word_probs = self.decoder(cnn_features, labels, counting_preds, images_mask)
  276. return word_probs, counting_preds, counting_preds1, counting_preds2