det_fce_loss.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
  17. """
  18. import numpy as np
  19. from paddle import nn
  20. import paddle
  21. import paddle.nn.functional as F
  22. from functools import partial
  23. def multi_apply(func, *args, **kwargs):
  24. pfunc = partial(func, **kwargs) if kwargs else func
  25. map_results = map(pfunc, *args)
  26. return tuple(map(list, zip(*map_results)))
  27. class FCELoss(nn.Layer):
  28. """The class for implementing FCENet loss
  29. FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
  30. Text Detection
  31. [https://arxiv.org/abs/2104.10442]
  32. Args:
  33. fourier_degree (int) : The maximum Fourier transform degree k.
  34. num_sample (int) : The sampling points number of regression
  35. loss. If it is too small, fcenet tends to be overfitting.
  36. ohem_ratio (float): the negative/positive ratio in OHEM.
  37. """
  38. def __init__(self, fourier_degree, num_sample, ohem_ratio=3.0):
  39. super().__init__()
  40. self.fourier_degree = fourier_degree
  41. self.num_sample = num_sample
  42. self.ohem_ratio = ohem_ratio
  43. def forward(self, preds, labels):
  44. assert isinstance(preds, dict)
  45. preds = preds["levels"]
  46. p3_maps, p4_maps, p5_maps = labels[1:]
  47. assert (
  48. p3_maps[0].shape[0] == 4 * self.fourier_degree + 5
  49. ), "fourier degree not equal in FCEhead and FCEtarget"
  50. # to tensor
  51. gts = [p3_maps, p4_maps, p5_maps]
  52. for idx, maps in enumerate(gts):
  53. gts[idx] = paddle.to_tensor(np.stack(maps))
  54. losses = multi_apply(self.forward_single, preds, gts)
  55. loss_tr = paddle.to_tensor(0.0).astype("float32")
  56. loss_tcl = paddle.to_tensor(0.0).astype("float32")
  57. loss_reg_x = paddle.to_tensor(0.0).astype("float32")
  58. loss_reg_y = paddle.to_tensor(0.0).astype("float32")
  59. loss_all = paddle.to_tensor(0.0).astype("float32")
  60. for idx, loss in enumerate(losses):
  61. loss_all += sum(loss)
  62. if idx == 0:
  63. loss_tr += sum(loss)
  64. elif idx == 1:
  65. loss_tcl += sum(loss)
  66. elif idx == 2:
  67. loss_reg_x += sum(loss)
  68. else:
  69. loss_reg_y += sum(loss)
  70. results = dict(
  71. loss=loss_all,
  72. loss_text=loss_tr,
  73. loss_center=loss_tcl,
  74. loss_reg_x=loss_reg_x,
  75. loss_reg_y=loss_reg_y,
  76. )
  77. return results
  78. def forward_single(self, pred, gt):
  79. cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
  80. reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
  81. gt = paddle.transpose(gt, (0, 2, 3, 1))
  82. k = 2 * self.fourier_degree + 1
  83. tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
  84. tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
  85. x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
  86. y_pred = paddle.reshape(reg_pred[:, :, :, k : 2 * k], (-1, k))
  87. tr_mask = gt[:, :, :, :1].reshape([-1])
  88. tcl_mask = gt[:, :, :, 1:2].reshape([-1])
  89. train_mask = gt[:, :, :, 2:3].reshape([-1])
  90. x_map = paddle.reshape(gt[:, :, :, 3 : 3 + k], (-1, k))
  91. y_map = paddle.reshape(gt[:, :, :, 3 + k :], (-1, k))
  92. tr_train_mask = (train_mask * tr_mask).astype("bool")
  93. tr_train_mask2 = paddle.concat(
  94. [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1
  95. )
  96. # tr loss
  97. loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
  98. # tcl loss
  99. loss_tcl = paddle.to_tensor(0.0).astype("float32")
  100. tr_neg_mask = tr_train_mask.logical_not()
  101. tr_neg_mask2 = paddle.concat(
  102. [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1
  103. )
  104. if tr_train_mask.sum().item() > 0:
  105. loss_tcl_pos = F.cross_entropy(
  106. tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
  107. tcl_mask.masked_select(tr_train_mask).astype("int64"),
  108. )
  109. loss_tcl_neg = F.cross_entropy(
  110. tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
  111. tcl_mask.masked_select(tr_neg_mask).astype("int64"),
  112. )
  113. loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
  114. # regression loss
  115. loss_reg_x = paddle.to_tensor(0.0).astype("float32")
  116. loss_reg_y = paddle.to_tensor(0.0).astype("float32")
  117. if tr_train_mask.sum().item() > 0:
  118. weight = (
  119. tr_mask.masked_select(tr_train_mask.astype("bool")).astype("float32")
  120. + tcl_mask.masked_select(tr_train_mask.astype("bool")).astype("float32")
  121. ) / 2
  122. weight = weight.reshape([-1, 1])
  123. ft_x, ft_y = self.fourier2poly(x_map, y_map)
  124. ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
  125. dim = ft_x.shape[1]
  126. tr_train_mask3 = paddle.concat(
  127. [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1
  128. )
  129. loss_reg_x = paddle.mean(
  130. weight
  131. * F.smooth_l1_loss(
  132. ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
  133. ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
  134. reduction="none",
  135. )
  136. )
  137. loss_reg_y = paddle.mean(
  138. weight
  139. * F.smooth_l1_loss(
  140. ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
  141. ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
  142. reduction="none",
  143. )
  144. )
  145. return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
  146. def ohem(self, predict, target, train_mask):
  147. pos = (target * train_mask).astype("bool")
  148. neg = ((1 - target) * train_mask).astype("bool")
  149. pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
  150. neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
  151. n_pos = pos.astype("float32").sum()
  152. if n_pos.item() > 0:
  153. loss_pos = F.cross_entropy(
  154. predict.masked_select(pos2).reshape([-1, 2]),
  155. target.masked_select(pos).astype("int64"),
  156. reduction="sum",
  157. )
  158. loss_neg = F.cross_entropy(
  159. predict.masked_select(neg2).reshape([-1, 2]),
  160. target.masked_select(neg).astype("int64"),
  161. reduction="none",
  162. )
  163. n_neg = min(
  164. int(neg.astype("float32").sum().item()),
  165. int(self.ohem_ratio * n_pos.astype("float32")),
  166. )
  167. else:
  168. loss_pos = paddle.to_tensor(0.0)
  169. loss_neg = F.cross_entropy(
  170. predict.masked_select(neg2).reshape([-1, 2]),
  171. target.masked_select(neg).astype("int64"),
  172. reduction="none",
  173. )
  174. n_neg = 100
  175. if len(loss_neg) > n_neg:
  176. loss_neg, _ = paddle.topk(loss_neg, n_neg)
  177. return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype("float32")
  178. def fourier2poly(self, real_maps, imag_maps):
  179. """Transform Fourier coefficient maps to polygon maps.
  180. Args:
  181. real_maps (tensor): A map composed of the real parts of the
  182. Fourier coefficients, whose shape is (-1, 2k+1)
  183. imag_maps (tensor):A map composed of the imag parts of the
  184. Fourier coefficients, whose shape is (-1, 2k+1)
  185. Returns
  186. x_maps (tensor): A map composed of the x value of the polygon
  187. represented by n sample points (xn, yn), whose shape is (-1, n)
  188. y_maps (tensor): A map composed of the y value of the polygon
  189. represented by n sample points (xn, yn), whose shape is (-1, n)
  190. """
  191. k_vect = paddle.arange(
  192. -self.fourier_degree, self.fourier_degree + 1, dtype="float32"
  193. ).reshape([-1, 1])
  194. i_vect = paddle.arange(0, self.num_sample, dtype="float32").reshape([1, -1])
  195. transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect, i_vect)
  196. x1 = paddle.einsum("ak, kn-> an", real_maps, paddle.cos(transform_matrix))
  197. x2 = paddle.einsum("ak, kn-> an", imag_maps, paddle.sin(transform_matrix))
  198. y1 = paddle.einsum("ak, kn-> an", real_maps, paddle.sin(transform_matrix))
  199. y2 = paddle.einsum("ak, kn-> an", imag_maps, paddle.cos(transform_matrix))
  200. x_maps = x1 - x2
  201. y_maps = y1 + y2
  202. return x_maps, y_maps