det_ct_loss.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/shengtao96/CentripetalText/tree/main/models/loss
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. from paddle import nn
  23. import paddle.nn.functional as F
  24. import numpy as np
  25. def ohem_single(score, gt_text, training_mask):
  26. # online hard example mining
  27. pos_num = int(paddle.sum(gt_text > 0.5)) - int(
  28. paddle.sum((gt_text > 0.5) & (training_mask <= 0.5))
  29. )
  30. if pos_num == 0:
  31. # selected_mask = gt_text.copy() * 0 # may be not good
  32. selected_mask = training_mask
  33. selected_mask = paddle.cast(
  34. selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
  35. "float32",
  36. )
  37. return selected_mask
  38. neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5)))
  39. neg_num = int(min(pos_num * 3, neg_num))
  40. if neg_num == 0:
  41. selected_mask = training_mask
  42. selected_mask = paddle.cast(
  43. selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
  44. "float32",
  45. )
  46. return selected_mask
  47. # hard example
  48. neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)]
  49. neg_score_sorted = paddle.sort(-neg_score)
  50. threshold = -neg_score_sorted[neg_num - 1]
  51. selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
  52. selected_mask = paddle.cast(
  53. selected_mask.reshape((1, selected_mask.shape[0], selected_mask.shape[1])),
  54. "float32",
  55. )
  56. return selected_mask
  57. def ohem_batch(scores, gt_texts, training_masks):
  58. selected_masks = []
  59. for i in range(scores.shape[0]):
  60. selected_masks.append(
  61. ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])
  62. )
  63. selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32")
  64. return selected_masks
  65. def iou_single(a, b, mask, n_class):
  66. EPS = 1e-6
  67. valid = mask == 1
  68. a = a[valid]
  69. b = b[valid]
  70. miou = []
  71. # iou of each class
  72. for i in range(n_class):
  73. inter = paddle.cast(((a == i) & (b == i)), "float32")
  74. union = paddle.cast(((a == i) | (b == i)), "float32")
  75. miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
  76. miou = sum(miou) / len(miou)
  77. return miou
  78. def iou(a, b, mask, n_class=2, reduce=True):
  79. batch_size = a.shape[0]
  80. a = a.reshape((batch_size, -1))
  81. b = b.reshape((batch_size, -1))
  82. mask = mask.reshape((batch_size, -1))
  83. iou = paddle.zeros((batch_size,), dtype="float32")
  84. for i in range(batch_size):
  85. iou[i] = iou_single(a[i], b[i], mask[i], n_class)
  86. if reduce:
  87. iou = paddle.mean(iou)
  88. return iou
  89. class DiceLoss(nn.Layer):
  90. def __init__(self, loss_weight=1.0):
  91. super(DiceLoss, self).__init__()
  92. self.loss_weight = loss_weight
  93. def forward(self, input, target, mask, reduce=True):
  94. batch_size = input.shape[0]
  95. input = F.sigmoid(input) # scale to 0-1
  96. input = input.reshape((batch_size, -1))
  97. target = paddle.cast(target.reshape((batch_size, -1)), "float32")
  98. mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
  99. input = input * mask
  100. target = target * mask
  101. a = paddle.sum(input * target, axis=1)
  102. b = paddle.sum(input * input, axis=1) + 0.001
  103. c = paddle.sum(target * target, axis=1) + 0.001
  104. d = (2 * a) / (b + c)
  105. loss = 1 - d
  106. loss = self.loss_weight * loss
  107. if reduce:
  108. loss = paddle.mean(loss)
  109. return loss
  110. class SmoothL1Loss(nn.Layer):
  111. def __init__(self, beta=1.0, loss_weight=1.0):
  112. super(SmoothL1Loss, self).__init__()
  113. self.beta = beta
  114. self.loss_weight = loss_weight
  115. np_coord = np.zeros(shape=[640, 640, 2], dtype=np.int64)
  116. for i in range(640):
  117. for j in range(640):
  118. np_coord[i, j, 0] = j
  119. np_coord[i, j, 1] = i
  120. np_coord = np_coord.reshape((-1, 2))
  121. self.coord = self.create_parameter(
  122. shape=[640 * 640, 2],
  123. dtype="int32", # NOTE: not support "int64" before paddle 2.3.1
  124. default_initializer=nn.initializer.Assign(value=np_coord),
  125. )
  126. self.coord.stop_gradient = True
  127. def forward_single(self, input, target, mask, beta=1.0, eps=1e-6):
  128. batch_size = input.shape[0]
  129. diff = paddle.abs(input - target) * mask.unsqueeze(1)
  130. loss = paddle.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
  131. loss = paddle.cast(loss.reshape((batch_size, -1)), "float32")
  132. mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
  133. loss = paddle.sum(loss, axis=-1)
  134. loss = loss / (mask.sum(axis=-1) + eps)
  135. return loss
  136. def select_single(self, distance, gt_instance, gt_kernel_instance, training_mask):
  137. with paddle.no_grad():
  138. # paddle 2.3.1, paddle.slice not support:
  139. # distance[:, self.coord[:, 1], self.coord[:, 0]]
  140. select_distance_list = []
  141. for i in range(2):
  142. tmp1 = distance[i, :]
  143. tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]]
  144. select_distance_list.append(tmp2.unsqueeze(0))
  145. select_distance = paddle.concat(select_distance_list, axis=0)
  146. off_points = paddle.cast(
  147. self.coord, "float32"
  148. ) + 10 * select_distance.transpose((1, 0))
  149. off_points = paddle.cast(off_points, "int64")
  150. off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1)
  151. selected_mask = (
  152. gt_instance[self.coord[:, 1], self.coord[:, 0]]
  153. != gt_kernel_instance[off_points[:, 1], off_points[:, 0]]
  154. )
  155. selected_mask = paddle.cast(
  156. selected_mask.reshape((1, -1, distance.shape[-1])), "int64"
  157. )
  158. selected_training_mask = selected_mask * training_mask
  159. return selected_training_mask
  160. def forward(
  161. self,
  162. distances,
  163. gt_instances,
  164. gt_kernel_instances,
  165. training_masks,
  166. gt_distances,
  167. reduce=True,
  168. ):
  169. selected_training_masks = []
  170. for i in range(distances.shape[0]):
  171. selected_training_masks.append(
  172. self.select_single(
  173. distances[i, :, :, :],
  174. gt_instances[i, :, :],
  175. gt_kernel_instances[i, :, :],
  176. training_masks[i, :, :],
  177. )
  178. )
  179. selected_training_masks = paddle.cast(
  180. paddle.concat(selected_training_masks, 0), "float32"
  181. )
  182. loss = self.forward_single(
  183. distances, gt_distances, selected_training_masks, self.beta
  184. )
  185. loss = self.loss_weight * loss
  186. with paddle.no_grad():
  187. batch_size = distances.shape[0]
  188. false_num = selected_training_masks.reshape((batch_size, -1))
  189. false_num = false_num.sum(axis=-1)
  190. total_num = paddle.cast(training_masks.reshape((batch_size, -1)), "float32")
  191. total_num = total_num.sum(axis=-1)
  192. iou_text = (total_num - false_num) / (total_num + 1e-6)
  193. if reduce:
  194. loss = paddle.mean(loss)
  195. return loss, iou_text
  196. class CTLoss(nn.Layer):
  197. def __init__(self):
  198. super(CTLoss, self).__init__()
  199. self.kernel_loss = DiceLoss()
  200. self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05)
  201. def forward(self, preds, batch):
  202. imgs = batch[0]
  203. out = preds["maps"]
  204. (
  205. gt_kernels,
  206. training_masks,
  207. gt_instances,
  208. gt_kernel_instances,
  209. training_mask_distances,
  210. gt_distances,
  211. ) = batch[1:]
  212. kernels = out[:, 0, :, :]
  213. distances = out[:, 1:, :, :]
  214. # kernel loss
  215. selected_masks = ohem_batch(kernels, gt_kernels, training_masks)
  216. loss_kernel = self.kernel_loss(
  217. kernels, gt_kernels, selected_masks, reduce=False
  218. )
  219. iou_kernel = iou(
  220. paddle.cast((kernels > 0), "int64"),
  221. gt_kernels,
  222. training_masks,
  223. reduce=False,
  224. )
  225. losses = dict(
  226. loss_kernels=loss_kernel,
  227. )
  228. # loc loss
  229. loss_loc, iou_text = self.loc_loss(
  230. distances,
  231. gt_instances,
  232. gt_kernel_instances,
  233. training_mask_distances,
  234. gt_distances,
  235. reduce=False,
  236. )
  237. losses.update(
  238. dict(
  239. loss_loc=loss_loc,
  240. )
  241. )
  242. loss_all = loss_kernel + loss_loc
  243. losses = {"loss": loss_all}
  244. return losses