distillation_loss.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. import numpy as np
  18. import cv2
  19. from .rec_ctc_loss import CTCLoss
  20. from .rec_sar_loss import SARLoss
  21. from .rec_ce_loss import CELoss
  22. from .basic_loss import DMLLoss, KLDivLoss, DKDLoss
  23. from .basic_loss import DistanceLoss
  24. from .basic_loss import LossFromOutput
  25. from .det_db_loss import DBLoss
  26. from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
  27. from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
  28. def _sum_loss(loss_dict):
  29. if "loss" in loss_dict.keys():
  30. return loss_dict
  31. else:
  32. loss_dict["loss"] = 0.0
  33. for k, value in loss_dict.items():
  34. if k == "loss":
  35. continue
  36. else:
  37. loss_dict["loss"] += value
  38. return loss_dict
  39. class DistillationDMLLoss(DMLLoss):
  40. """ """
  41. def __init__(
  42. self,
  43. model_name_pairs=[],
  44. act=None,
  45. use_log=False,
  46. key=None,
  47. multi_head=False,
  48. dis_head="ctc",
  49. maps_name=None,
  50. name="dml",
  51. ):
  52. super().__init__(act=act, use_log=use_log)
  53. assert isinstance(model_name_pairs, list)
  54. self.key = key
  55. self.multi_head = multi_head
  56. self.dis_head = dis_head
  57. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  58. self.name = name
  59. self.maps_name = self._check_maps_name(maps_name)
  60. def _check_model_name_pairs(self, model_name_pairs):
  61. if not isinstance(model_name_pairs, list):
  62. return []
  63. elif isinstance(model_name_pairs[0], list) and isinstance(
  64. model_name_pairs[0][0], str
  65. ):
  66. return model_name_pairs
  67. else:
  68. return [model_name_pairs]
  69. def _check_maps_name(self, maps_name):
  70. if maps_name is None:
  71. return None
  72. elif isinstance(maps_name, str):
  73. return [maps_name]
  74. elif isinstance(maps_name, list):
  75. return [maps_name]
  76. else:
  77. return None
  78. def _slice_out(self, outs):
  79. new_outs = {}
  80. for k in self.maps_name:
  81. if k == "thrink_maps":
  82. new_outs[k] = outs[:, 0, :, :]
  83. elif k == "threshold_maps":
  84. new_outs[k] = outs[:, 1, :, :]
  85. elif k == "binary_maps":
  86. new_outs[k] = outs[:, 2, :, :]
  87. else:
  88. continue
  89. return new_outs
  90. def forward(self, predicts, batch):
  91. loss_dict = dict()
  92. for idx, pair in enumerate(self.model_name_pairs):
  93. out1 = predicts[pair[0]]
  94. out2 = predicts[pair[1]]
  95. if self.key is not None:
  96. out1 = out1[self.key]
  97. out2 = out2[self.key]
  98. if self.maps_name is None:
  99. if self.multi_head:
  100. loss = super().forward(out1[self.dis_head], out2[self.dis_head])
  101. else:
  102. loss = super().forward(out1, out2)
  103. if isinstance(loss, dict):
  104. for key in loss:
  105. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = (
  106. loss[key]
  107. )
  108. else:
  109. loss_dict["{}_{}".format(self.name, idx)] = loss
  110. else:
  111. outs1 = self._slice_out(out1)
  112. outs2 = self._slice_out(out2)
  113. for _c, k in enumerate(outs1.keys()):
  114. loss = super().forward(outs1[k], outs2[k])
  115. if isinstance(loss, dict):
  116. for key in loss:
  117. loss_dict[
  118. "{}_{}_{}_{}_{}".format(
  119. key, pair[0], pair[1], self.maps_name, idx
  120. )
  121. ] = loss[key]
  122. else:
  123. loss_dict[
  124. "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
  125. ] = loss
  126. loss_dict = _sum_loss(loss_dict)
  127. return loss_dict
  128. class DistillationKLDivLoss(KLDivLoss):
  129. """ """
  130. def __init__(
  131. self,
  132. model_name_pairs=[],
  133. key=None,
  134. multi_head=False,
  135. dis_head="ctc",
  136. maps_name=None,
  137. name="kl_div",
  138. ):
  139. super().__init__()
  140. assert isinstance(model_name_pairs, list)
  141. self.key = key
  142. self.multi_head = multi_head
  143. self.dis_head = dis_head
  144. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  145. self.name = name
  146. self.maps_name = self._check_maps_name(maps_name)
  147. def _check_model_name_pairs(self, model_name_pairs):
  148. if not isinstance(model_name_pairs, list):
  149. return []
  150. elif isinstance(model_name_pairs[0], list) and isinstance(
  151. model_name_pairs[0][0], str
  152. ):
  153. return model_name_pairs
  154. else:
  155. return [model_name_pairs]
  156. def _check_maps_name(self, maps_name):
  157. if maps_name is None:
  158. return None
  159. elif isinstance(maps_name, str):
  160. return [maps_name]
  161. elif isinstance(maps_name, list):
  162. return [maps_name]
  163. else:
  164. return None
  165. def _slice_out(self, outs):
  166. new_outs = {}
  167. for k in self.maps_name:
  168. if k == "thrink_maps":
  169. new_outs[k] = outs[:, 0, :, :]
  170. elif k == "threshold_maps":
  171. new_outs[k] = outs[:, 1, :, :]
  172. elif k == "binary_maps":
  173. new_outs[k] = outs[:, 2, :, :]
  174. else:
  175. continue
  176. return new_outs
  177. def forward(self, predicts, batch):
  178. loss_dict = dict()
  179. for idx, pair in enumerate(self.model_name_pairs):
  180. out1 = predicts[pair[0]]
  181. out2 = predicts[pair[1]]
  182. if self.key is not None:
  183. out1 = out1[self.key]
  184. out2 = out2[self.key]
  185. if self.maps_name is None:
  186. if self.multi_head:
  187. # for nrtr dml loss
  188. max_len = batch[3].max()
  189. tgt = batch[2][:, 1 : 2 + max_len]
  190. tgt = tgt.reshape([-1])
  191. non_pad_mask = paddle.not_equal(
  192. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  193. )
  194. loss = super().forward(
  195. out1[self.dis_head], out2[self.dis_head], non_pad_mask
  196. )
  197. else:
  198. loss = super().forward(out1, out2)
  199. if isinstance(loss, dict):
  200. for key in loss:
  201. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = (
  202. loss[key]
  203. )
  204. else:
  205. loss_dict["{}_{}".format(self.name, idx)] = loss
  206. else:
  207. outs1 = self._slice_out(out1)
  208. outs2 = self._slice_out(out2)
  209. for _c, k in enumerate(outs1.keys()):
  210. loss = super().forward(outs1[k], outs2[k])
  211. if isinstance(loss, dict):
  212. for key in loss:
  213. loss_dict[
  214. "{}_{}_{}_{}_{}".format(
  215. key, pair[0], pair[1], self.maps_name, idx
  216. )
  217. ] = loss[key]
  218. else:
  219. loss_dict[
  220. "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
  221. ] = loss
  222. loss_dict = _sum_loss(loss_dict)
  223. return loss_dict
  224. class DistillationDKDLoss(DKDLoss):
  225. """ """
  226. def __init__(
  227. self,
  228. model_name_pairs=[],
  229. key=None,
  230. multi_head=False,
  231. dis_head="ctc",
  232. maps_name=None,
  233. name="dkd",
  234. temperature=1.0,
  235. alpha=1.0,
  236. beta=1.0,
  237. ):
  238. super().__init__(temperature, alpha, beta)
  239. assert isinstance(model_name_pairs, list)
  240. self.key = key
  241. self.multi_head = multi_head
  242. self.dis_head = dis_head
  243. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  244. self.name = name
  245. self.maps_name = self._check_maps_name(maps_name)
  246. def _check_model_name_pairs(self, model_name_pairs):
  247. if not isinstance(model_name_pairs, list):
  248. return []
  249. elif isinstance(model_name_pairs[0], list) and isinstance(
  250. model_name_pairs[0][0], str
  251. ):
  252. return model_name_pairs
  253. else:
  254. return [model_name_pairs]
  255. def _check_maps_name(self, maps_name):
  256. if maps_name is None:
  257. return None
  258. elif isinstance(maps_name, str):
  259. return [maps_name]
  260. elif isinstance(maps_name, list):
  261. return [maps_name]
  262. else:
  263. return None
  264. def _slice_out(self, outs):
  265. new_outs = {}
  266. for k in self.maps_name:
  267. if k == "thrink_maps":
  268. new_outs[k] = outs[:, 0, :, :]
  269. elif k == "threshold_maps":
  270. new_outs[k] = outs[:, 1, :, :]
  271. elif k == "binary_maps":
  272. new_outs[k] = outs[:, 2, :, :]
  273. else:
  274. continue
  275. return new_outs
  276. def forward(self, predicts, batch):
  277. loss_dict = dict()
  278. for idx, pair in enumerate(self.model_name_pairs):
  279. out1 = predicts[pair[0]]
  280. out2 = predicts[pair[1]]
  281. if self.key is not None:
  282. out1 = out1[self.key]
  283. out2 = out2[self.key]
  284. if self.maps_name is None:
  285. if self.multi_head:
  286. # for nrtr dml loss
  287. max_len = batch[3].max()
  288. tgt = batch[2][:, 1 : 2 + max_len] # [batch_size, max_len + 1]
  289. tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
  290. non_pad_mask = paddle.not_equal(
  291. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  292. ) # batch_size * (max_len + 1)
  293. loss = super().forward(
  294. out1[self.dis_head], out2[self.dis_head], tgt, non_pad_mask
  295. ) # [batch_size, max_len + 1, num_char]
  296. else:
  297. loss = super().forward(out1, out2)
  298. if isinstance(loss, dict):
  299. for key in loss:
  300. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = (
  301. loss[key]
  302. )
  303. else:
  304. loss_dict["{}_{}".format(self.name, idx)] = loss
  305. else:
  306. outs1 = self._slice_out(out1)
  307. outs2 = self._slice_out(out2)
  308. for _c, k in enumerate(outs1.keys()):
  309. loss = super().forward(outs1[k], outs2[k])
  310. if isinstance(loss, dict):
  311. for key in loss:
  312. loss_dict[
  313. "{}_{}_{}_{}_{}".format(
  314. key, pair[0], pair[1], self.maps_name, idx
  315. )
  316. ] = loss[key]
  317. else:
  318. loss_dict[
  319. "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
  320. ] = loss
  321. loss_dict = _sum_loss(loss_dict)
  322. return loss_dict
  323. class DistillationNRTRDMLLoss(DistillationDMLLoss):
  324. """ """
  325. def forward(self, predicts, batch):
  326. loss_dict = dict()
  327. for idx, pair in enumerate(self.model_name_pairs):
  328. out1 = predicts[pair[0]]
  329. out2 = predicts[pair[1]]
  330. if self.key is not None:
  331. out1 = out1[self.key]
  332. out2 = out2[self.key]
  333. if self.multi_head:
  334. # for nrtr dml loss
  335. max_len = batch[3].max()
  336. tgt = batch[2][:, 1 : 2 + max_len]
  337. tgt = tgt.reshape([-1])
  338. non_pad_mask = paddle.not_equal(
  339. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  340. )
  341. loss = super().forward(
  342. out1[self.dis_head], out2[self.dis_head], non_pad_mask
  343. )
  344. else:
  345. loss = super().forward(out1, out2)
  346. if isinstance(loss, dict):
  347. for key in loss:
  348. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = loss[
  349. key
  350. ]
  351. else:
  352. loss_dict["{}_{}".format(self.name, idx)] = loss
  353. loss_dict = _sum_loss(loss_dict)
  354. return loss_dict
  355. class DistillationKLDivLoss(KLDivLoss):
  356. """ """
  357. def __init__(
  358. self,
  359. model_name_pairs=[],
  360. key=None,
  361. multi_head=False,
  362. dis_head="ctc",
  363. maps_name=None,
  364. name="kl_div",
  365. ):
  366. super().__init__()
  367. assert isinstance(model_name_pairs, list)
  368. self.key = key
  369. self.multi_head = multi_head
  370. self.dis_head = dis_head
  371. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  372. self.name = name
  373. self.maps_name = self._check_maps_name(maps_name)
  374. def _check_model_name_pairs(self, model_name_pairs):
  375. if not isinstance(model_name_pairs, list):
  376. return []
  377. elif isinstance(model_name_pairs[0], list) and isinstance(
  378. model_name_pairs[0][0], str
  379. ):
  380. return model_name_pairs
  381. else:
  382. return [model_name_pairs]
  383. def _check_maps_name(self, maps_name):
  384. if maps_name is None:
  385. return None
  386. elif isinstance(maps_name, str):
  387. return [maps_name]
  388. elif isinstance(maps_name, list):
  389. return [maps_name]
  390. else:
  391. return None
  392. def _slice_out(self, outs):
  393. new_outs = {}
  394. for k in self.maps_name:
  395. if k == "thrink_maps":
  396. new_outs[k] = outs[:, 0, :, :]
  397. elif k == "threshold_maps":
  398. new_outs[k] = outs[:, 1, :, :]
  399. elif k == "binary_maps":
  400. new_outs[k] = outs[:, 2, :, :]
  401. else:
  402. continue
  403. return new_outs
  404. def forward(self, predicts, batch):
  405. loss_dict = dict()
  406. for idx, pair in enumerate(self.model_name_pairs):
  407. out1 = predicts[pair[0]]
  408. out2 = predicts[pair[1]]
  409. if self.key is not None:
  410. out1 = out1[self.key]
  411. out2 = out2[self.key]
  412. if self.maps_name is None:
  413. if self.multi_head:
  414. # for nrtr dml loss
  415. max_len = batch[3].max()
  416. tgt = batch[2][:, 1 : 2 + max_len]
  417. tgt = tgt.reshape([-1])
  418. non_pad_mask = paddle.not_equal(
  419. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  420. )
  421. loss = super().forward(
  422. out1[self.dis_head], out2[self.dis_head], non_pad_mask
  423. )
  424. else:
  425. loss = super().forward(out1, out2)
  426. if isinstance(loss, dict):
  427. for key in loss:
  428. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = (
  429. loss[key]
  430. )
  431. else:
  432. loss_dict["{}_{}".format(self.name, idx)] = loss
  433. else:
  434. outs1 = self._slice_out(out1)
  435. outs2 = self._slice_out(out2)
  436. for _c, k in enumerate(outs1.keys()):
  437. loss = super().forward(outs1[k], outs2[k])
  438. if isinstance(loss, dict):
  439. for key in loss:
  440. loss_dict[
  441. "{}_{}_{}_{}_{}".format(
  442. key, pair[0], pair[1], self.maps_name, idx
  443. )
  444. ] = loss[key]
  445. else:
  446. loss_dict[
  447. "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
  448. ] = loss
  449. loss_dict = _sum_loss(loss_dict)
  450. return loss_dict
  451. class DistillationDKDLoss(DKDLoss):
  452. """ """
  453. def __init__(
  454. self,
  455. model_name_pairs=[],
  456. key=None,
  457. multi_head=False,
  458. dis_head="ctc",
  459. maps_name=None,
  460. name="dkd",
  461. temperature=1.0,
  462. alpha=1.0,
  463. beta=1.0,
  464. ):
  465. super().__init__(temperature, alpha, beta)
  466. assert isinstance(model_name_pairs, list)
  467. self.key = key
  468. self.multi_head = multi_head
  469. self.dis_head = dis_head
  470. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  471. self.name = name
  472. self.maps_name = self._check_maps_name(maps_name)
  473. def _check_model_name_pairs(self, model_name_pairs):
  474. if not isinstance(model_name_pairs, list):
  475. return []
  476. elif isinstance(model_name_pairs[0], list) and isinstance(
  477. model_name_pairs[0][0], str
  478. ):
  479. return model_name_pairs
  480. else:
  481. return [model_name_pairs]
  482. def _check_maps_name(self, maps_name):
  483. if maps_name is None:
  484. return None
  485. elif isinstance(maps_name, str):
  486. return [maps_name]
  487. elif isinstance(maps_name, list):
  488. return [maps_name]
  489. else:
  490. return None
  491. def _slice_out(self, outs):
  492. new_outs = {}
  493. for k in self.maps_name:
  494. if k == "thrink_maps":
  495. new_outs[k] = outs[:, 0, :, :]
  496. elif k == "threshold_maps":
  497. new_outs[k] = outs[:, 1, :, :]
  498. elif k == "binary_maps":
  499. new_outs[k] = outs[:, 2, :, :]
  500. else:
  501. continue
  502. return new_outs
  503. def forward(self, predicts, batch):
  504. loss_dict = dict()
  505. for idx, pair in enumerate(self.model_name_pairs):
  506. out1 = predicts[pair[0]]
  507. out2 = predicts[pair[1]]
  508. if self.key is not None:
  509. out1 = out1[self.key]
  510. out2 = out2[self.key]
  511. if self.maps_name is None:
  512. if self.multi_head:
  513. # for nrtr dml loss
  514. max_len = batch[3].max()
  515. tgt = batch[2][:, 1 : 2 + max_len] # [batch_size, max_len + 1]
  516. tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
  517. non_pad_mask = paddle.not_equal(
  518. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  519. ) # batch_size * (max_len + 1)
  520. loss = super().forward(
  521. out1[self.dis_head], out2[self.dis_head], tgt, non_pad_mask
  522. ) # [batch_size, max_len + 1, num_char]
  523. else:
  524. loss = super().forward(out1, out2)
  525. if isinstance(loss, dict):
  526. for key in loss:
  527. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], idx)] = (
  528. loss[key]
  529. )
  530. else:
  531. loss_dict["{}_{}".format(self.name, idx)] = loss
  532. else:
  533. outs1 = self._slice_out(out1)
  534. outs2 = self._slice_out(out2)
  535. for _c, k in enumerate(outs1.keys()):
  536. loss = super().forward(outs1[k], outs2[k])
  537. if isinstance(loss, dict):
  538. for key in loss:
  539. loss_dict[
  540. "{}_{}_{}_{}_{}".format(
  541. key, pair[0], pair[1], self.maps_name, idx
  542. )
  543. ] = loss[key]
  544. else:
  545. loss_dict[
  546. "{}_{}_{}".format(self.name, self.maps_name[_c], idx)
  547. ] = loss
  548. loss_dict = _sum_loss(loss_dict)
  549. return loss_dict
  550. class DistillationCTCLoss(CTCLoss):
  551. def __init__(self, model_name_list=[], key=None, multi_head=False, name="loss_ctc"):
  552. super().__init__()
  553. self.model_name_list = model_name_list
  554. self.key = key
  555. self.name = name
  556. self.multi_head = multi_head
  557. def forward(self, predicts, batch):
  558. loss_dict = dict()
  559. for idx, model_name in enumerate(self.model_name_list):
  560. out = predicts[model_name]
  561. if self.key is not None:
  562. out = out[self.key]
  563. if self.multi_head:
  564. assert "ctc" in out, "multi head has multi out"
  565. loss = super().forward(out["ctc"], batch[:2] + batch[3:])
  566. else:
  567. loss = super().forward(out, batch)
  568. if isinstance(loss, dict):
  569. for key in loss:
  570. loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
  571. else:
  572. loss_dict["{}_{}".format(self.name, model_name)] = loss
  573. return loss_dict
  574. class DistillationSARLoss(SARLoss):
  575. def __init__(
  576. self, model_name_list=[], key=None, multi_head=False, name="loss_sar", **kwargs
  577. ):
  578. ignore_index = kwargs.get("ignore_index", 92)
  579. super().__init__(ignore_index=ignore_index)
  580. self.model_name_list = model_name_list
  581. self.key = key
  582. self.name = name
  583. self.multi_head = multi_head
  584. def forward(self, predicts, batch):
  585. loss_dict = dict()
  586. for idx, model_name in enumerate(self.model_name_list):
  587. out = predicts[model_name]
  588. if self.key is not None:
  589. out = out[self.key]
  590. if self.multi_head:
  591. assert "sar" in out, "multi head has multi out"
  592. loss = super().forward(out["sar"], batch[:1] + batch[2:])
  593. else:
  594. loss = super().forward(out, batch)
  595. if isinstance(loss, dict):
  596. for key in loss:
  597. loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
  598. else:
  599. loss_dict["{}_{}".format(self.name, model_name)] = loss
  600. return loss_dict
  601. class DistillationNRTRLoss(CELoss):
  602. def __init__(
  603. self,
  604. model_name_list=[],
  605. key=None,
  606. multi_head=False,
  607. smoothing=True,
  608. name="loss_nrtr",
  609. **kwargs,
  610. ):
  611. super().__init__(smoothing=smoothing)
  612. self.model_name_list = model_name_list
  613. self.key = key
  614. self.name = name
  615. self.multi_head = multi_head
  616. def forward(self, predicts, batch):
  617. loss_dict = dict()
  618. for idx, model_name in enumerate(self.model_name_list):
  619. out = predicts[model_name]
  620. if self.key is not None:
  621. out = out[self.key]
  622. if self.multi_head:
  623. assert "gtc" in out, "multi head has multi out"
  624. loss = super().forward(out["gtc"], batch[:1] + batch[2:])
  625. else:
  626. loss = super().forward(out, batch)
  627. if isinstance(loss, dict):
  628. for key in loss:
  629. loss_dict["{}_{}_{}".format(self.name, model_name, idx)] = loss[key]
  630. else:
  631. loss_dict["{}_{}".format(self.name, model_name)] = loss
  632. return loss_dict
  633. class DistillationDBLoss(DBLoss):
  634. def __init__(
  635. self,
  636. model_name_list=[],
  637. balance_loss=True,
  638. main_loss_type="DiceLoss",
  639. alpha=5,
  640. beta=10,
  641. ohem_ratio=3,
  642. eps=1e-6,
  643. name="db",
  644. **kwargs,
  645. ):
  646. super().__init__()
  647. self.model_name_list = model_name_list
  648. self.name = name
  649. self.key = None
  650. def forward(self, predicts, batch):
  651. loss_dict = {}
  652. for idx, model_name in enumerate(self.model_name_list):
  653. out = predicts[model_name]
  654. if self.key is not None:
  655. out = out[self.key]
  656. loss = super().forward(out, batch)
  657. if isinstance(loss, dict):
  658. for key in loss.keys():
  659. if key == "loss":
  660. continue
  661. name = "{}_{}_{}".format(self.name, model_name, key)
  662. loss_dict[name] = loss[key]
  663. else:
  664. loss_dict["{}_{}".format(self.name, model_name)] = loss
  665. loss_dict = _sum_loss(loss_dict)
  666. return loss_dict
  667. class DistillationDilaDBLoss(DBLoss):
  668. def __init__(
  669. self,
  670. model_name_pairs=[],
  671. key=None,
  672. balance_loss=True,
  673. main_loss_type="DiceLoss",
  674. alpha=5,
  675. beta=10,
  676. ohem_ratio=3,
  677. eps=1e-6,
  678. name="dila_dbloss",
  679. ):
  680. super().__init__()
  681. self.model_name_pairs = model_name_pairs
  682. self.name = name
  683. self.key = key
  684. def forward(self, predicts, batch):
  685. loss_dict = dict()
  686. for idx, pair in enumerate(self.model_name_pairs):
  687. stu_outs = predicts[pair[0]]
  688. tch_outs = predicts[pair[1]]
  689. if self.key is not None:
  690. stu_preds = stu_outs[self.key]
  691. tch_preds = tch_outs[self.key]
  692. stu_shrink_maps = stu_preds[:, 0, :, :]
  693. stu_binary_maps = stu_preds[:, 2, :, :]
  694. # dilation to teacher prediction
  695. dilation_w = np.array([[1, 1], [1, 1]])
  696. th_shrink_maps = tch_preds[:, 0, :, :]
  697. if hasattr(paddle.Tensor, "contiguous"):
  698. th_shrink_maps = th_shrink_maps.contiguous()
  699. th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
  700. dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
  701. for i in range(th_shrink_maps.shape[0]):
  702. dilate_maps[i] = cv2.dilate(
  703. th_shrink_maps[i, :, :].astype(np.uint8), dilation_w
  704. )
  705. th_shrink_maps = paddle.to_tensor(dilate_maps)
  706. (
  707. label_threshold_map,
  708. label_threshold_mask,
  709. label_shrink_map,
  710. label_shrink_mask,
  711. ) = batch[1:]
  712. # calculate the shrink map loss
  713. bce_loss = self.alpha * self.bce_loss(
  714. stu_shrink_maps, th_shrink_maps, label_shrink_mask
  715. )
  716. loss_binary_maps = self.dice_loss(
  717. stu_binary_maps, th_shrink_maps, label_shrink_mask
  718. )
  719. # k = f"{self.name}_{pair[0]}_{pair[1]}"
  720. k = "{}_{}_{}".format(self.name, pair[0], pair[1])
  721. loss_dict[k] = bce_loss + loss_binary_maps
  722. loss_dict = _sum_loss(loss_dict)
  723. return loss_dict
  724. class DistillationDistanceLoss(DistanceLoss):
  725. """ """
  726. def __init__(
  727. self, mode="l2", model_name_pairs=[], key=None, name="loss_distance", **kargs
  728. ):
  729. super().__init__(mode=mode, **kargs)
  730. assert isinstance(model_name_pairs, list)
  731. self.key = key
  732. self.model_name_pairs = model_name_pairs
  733. self.name = name + "_l2"
  734. def forward(self, predicts, batch):
  735. loss_dict = dict()
  736. for idx, pair in enumerate(self.model_name_pairs):
  737. out1 = predicts[pair[0]]
  738. out2 = predicts[pair[1]]
  739. if self.key is not None:
  740. out1 = out1[self.key]
  741. out2 = out2[self.key]
  742. loss = super().forward(out1, out2)
  743. if isinstance(loss, dict):
  744. for key in loss:
  745. loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
  746. else:
  747. loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)] = loss
  748. return loss_dict
  749. class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
  750. def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):
  751. super().__init__(num_classes=num_classes)
  752. self.model_name_list = model_name_list
  753. self.key = key
  754. self.name = name
  755. def forward(self, predicts, batch):
  756. loss_dict = dict()
  757. for idx, model_name in enumerate(self.model_name_list):
  758. out = predicts[model_name]
  759. if self.key is not None:
  760. out = out[self.key]
  761. loss = super().forward(out, batch)
  762. loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
  763. return loss_dict
  764. class DistillationLossFromOutput(LossFromOutput):
  765. def __init__(
  766. self,
  767. reduction="none",
  768. model_name_list=[],
  769. dist_key=None,
  770. key="loss",
  771. name="loss_re",
  772. ):
  773. super().__init__(key=key, reduction=reduction)
  774. self.model_name_list = model_name_list
  775. self.name = name
  776. self.dist_key = dist_key
  777. def forward(self, predicts, batch):
  778. loss_dict = dict()
  779. for idx, model_name in enumerate(self.model_name_list):
  780. out = predicts[model_name]
  781. if self.dist_key is not None:
  782. out = out[self.dist_key]
  783. loss = super().forward(out, batch)
  784. loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
  785. return loss_dict
  786. class DistillationSERDMLLoss(DMLLoss):
  787. """ """
  788. def __init__(
  789. self,
  790. act="softmax",
  791. use_log=True,
  792. num_classes=7,
  793. model_name_pairs=[],
  794. key=None,
  795. name="loss_dml_ser",
  796. ):
  797. super().__init__(act=act, use_log=use_log)
  798. assert isinstance(model_name_pairs, list)
  799. self.key = key
  800. self.name = name
  801. self.num_classes = num_classes
  802. self.model_name_pairs = model_name_pairs
  803. def forward(self, predicts, batch):
  804. loss_dict = dict()
  805. for idx, pair in enumerate(self.model_name_pairs):
  806. out1 = predicts[pair[0]]
  807. out2 = predicts[pair[1]]
  808. if self.key is not None:
  809. out1 = out1[self.key]
  810. out2 = out2[self.key]
  811. out1 = out1.reshape([-1, out1.shape[-1]])
  812. out2 = out2.reshape([-1, out2.shape[-1]])
  813. attention_mask = batch[2]
  814. if attention_mask is not None:
  815. active_output = (
  816. attention_mask.reshape(
  817. [
  818. -1,
  819. ]
  820. )
  821. == 1
  822. )
  823. out1 = out1[active_output]
  824. out2 = out2[active_output]
  825. loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)
  826. return loss_dict
  827. class DistillationVQADistanceLoss(DistanceLoss):
  828. def __init__(
  829. self,
  830. mode="l2",
  831. model_name_pairs=[],
  832. key=None,
  833. index=None,
  834. name="loss_distance",
  835. **kargs,
  836. ):
  837. super().__init__(mode=mode, **kargs)
  838. assert isinstance(model_name_pairs, list)
  839. self.key = key
  840. self.index = index
  841. self.model_name_pairs = model_name_pairs
  842. self.name = name + "_l2"
  843. def forward(self, predicts, batch):
  844. loss_dict = dict()
  845. for idx, pair in enumerate(self.model_name_pairs):
  846. out1 = predicts[pair[0]]
  847. out2 = predicts[pair[1]]
  848. attention_mask = batch[2]
  849. if self.key is not None:
  850. out1 = out1[self.key]
  851. out2 = out2[self.key]
  852. if self.index is not None:
  853. out1 = out1[:, self.index, :, :]
  854. out2 = out2[:, self.index, :, :]
  855. if attention_mask is not None:
  856. max_len = attention_mask.shape[-1]
  857. out1 = out1[:, :max_len]
  858. out2 = out2[:, :max_len]
  859. out1 = out1.reshape([-1, out1.shape[-1]])
  860. out2 = out2.reshape([-1, out2.shape[-1]])
  861. if attention_mask is not None:
  862. active_output = (
  863. attention_mask.reshape(
  864. [
  865. -1,
  866. ]
  867. )
  868. == 1
  869. )
  870. out1 = out1[active_output]
  871. out2 = out2[active_output]
  872. loss = super().forward(out1, out2)
  873. if isinstance(loss, dict):
  874. for key in loss:
  875. loss_dict["{}_{}nohu_{}".format(self.name, key, idx)] = loss[key]
  876. else:
  877. loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)] = loss
  878. return loss_dict
  879. class CTCDKDLoss(nn.Layer):
  880. """
  881. KLDivLoss
  882. """
  883. def __init__(self, temperature=0.5, alpha=1.0, beta=1.0):
  884. super().__init__()
  885. self.temperature = temperature
  886. self.alpha = alpha
  887. self.beta = beta
  888. self.eps = 1e-6
  889. self.t = temperature
  890. self.act = nn.Softmax(axis=-1)
  891. self.use_log = True
  892. def kl_loss(self, p1, p2): # predict, label
  893. loss = paddle.multiply(
  894. p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)
  895. )
  896. bs = loss.shape[0]
  897. loss = paddle.sum(loss) / bs
  898. return loss
  899. def _cat_mask(self, t, mask1, mask2):
  900. t1 = (t * mask1).sum(axis=1, keepdim=True)
  901. t2 = (t * mask2).sum(axis=1, keepdim=True)
  902. rt = paddle.concat([t1, t2], axis=1)
  903. return rt
  904. def multi_label_mask(self, targets):
  905. targets = targets.astype("int32")
  906. res = F.one_hot(targets, num_classes=11465)
  907. mask = paddle.clip(paddle.sum(res, axis=1), 0, 1)
  908. mask[:, 0] = 0 # ignore ctc blank label
  909. return mask
  910. def forward(self, logits_student, logits_teacher, targets, mask=None):
  911. gt_mask = self.multi_label_mask(targets)
  912. other_mask = paddle.ones_like(gt_mask) - gt_mask
  913. pred_student = F.softmax(logits_student / self.temperature, axis=-1)
  914. pred_teacher = F.softmax(logits_teacher / self.temperature, axis=-1)
  915. # differences with dkd
  916. pred_student = paddle.mean(pred_student, axis=1)
  917. pred_teacher = paddle.mean(pred_teacher, axis=1)
  918. pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
  919. pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
  920. # differences with dkd
  921. tckd_loss = self.kl_loss(pred_student, pred_teacher)
  922. gt_mask_ex = paddle.expand_as(gt_mask.unsqueeze(axis=1), logits_teacher)
  923. pred_teacher_part2 = F.softmax(
  924. logits_teacher / self.temperature - 1000.0 * gt_mask_ex, axis=-1
  925. )
  926. pred_student_part2 = F.softmax(
  927. logits_student / self.temperature - 1000.0 * gt_mask_ex, axis=-1
  928. )
  929. # differences with dkd
  930. pred_teacher_part2 = paddle.mean(pred_teacher_part2, axis=1)
  931. pred_student_part2 = paddle.mean(pred_student_part2, axis=1)
  932. # differences with dkd
  933. nckd_loss = self.kl_loss(pred_student_part2, pred_teacher_part2)
  934. loss = self.alpha * tckd_loss + self.beta * nckd_loss
  935. return loss
  936. class KLCTCLogits(nn.Layer):
  937. def __init__(self, weight=1.0, reduction="mean", mode="mean"):
  938. super().__init__()
  939. self.weight = weight
  940. self.reduction = reduction
  941. self.eps = 1e-6
  942. self.t = 0.5
  943. self.act = nn.Softmax(axis=-1)
  944. self.use_log = True
  945. self.mode = mode
  946. self.ctc_dkd_loss = CTCDKDLoss()
  947. def kl_loss(self, p1, p2): # predict, label
  948. loss = paddle.multiply(
  949. p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps)
  950. )
  951. bs = loss.shape[0]
  952. loss = paddle.sum(loss) / bs
  953. return loss
  954. def forward_meanmax(self, stu_out, tea_out):
  955. stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1)
  956. tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1)
  957. loss = self.kl_loss(stu_out, tea_out)
  958. return loss
  959. def forward_meanlog(self, stu_out, tea_out):
  960. stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1)
  961. tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1)
  962. if self.use_log is True:
  963. # for recognition distillation, log is needed for feature map
  964. log_out1 = paddle.log(stu_out)
  965. log_out2 = paddle.log(tea_out)
  966. loss = (
  967. self._kldiv(log_out1, tea_out) + self._kldiv(log_out2, stu_out)
  968. ) / 2.0
  969. return loss
  970. def forward_sum(self, stu_out, tea_out):
  971. stu_out = paddle.sum(F.softmax(stu_out / self.t, axis=-1), axis=1)
  972. tea_out = paddle.sum(F.softmax(tea_out / self.t, axis=-1), axis=1)
  973. stu_out = paddle.log(stu_out)
  974. bs = stu_out.shape[0]
  975. loss = tea_out * (paddle.log(tea_out + self.eps) - stu_out)
  976. loss = paddle.sum(loss, axis=1) / loss.shape[0]
  977. return loss
  978. def _kldiv(self, x, target):
  979. eps = 1.0e-10
  980. loss = target * (paddle.log(target + eps) - x)
  981. loss = paddle.sum(paddle.mean(loss, axis=1)) / loss.shape[0]
  982. return loss
  983. def forward(self, stu_out, tea_out, targets=None):
  984. if self.mode == "log":
  985. return self.forward_log(stu_out, tea_out)
  986. elif self.mode == "mean":
  987. blank_mask = paddle.ones_like(stu_out)
  988. blank_mask.stop_gradient = True
  989. blank_mask[:, :, 0] = -1
  990. stu_out *= blank_mask
  991. tea_out *= blank_mask
  992. return self.forward_meanmax(stu_out, tea_out)
  993. elif self.mode == "sum":
  994. return self.forward_sum(stu_out, tea_out)
  995. elif self.mode == "meanlog":
  996. blank_mask = paddle.ones_like(stu_out)
  997. blank_mask.stop_gradient = True
  998. blank_mask[:, :, 0] = -1
  999. stu_out *= blank_mask
  1000. tea_out *= blank_mask
  1001. return self.forward_meanlog(stu_out, tea_out)
  1002. elif self.mode == "ctcdkd":
  1003. # ignore ctc blank logits
  1004. blank_mask = paddle.ones_like(stu_out)
  1005. blank_mask.stop_gradient = True
  1006. blank_mask[:, :, 0] = -1
  1007. stu_out *= blank_mask
  1008. tea_out *= blank_mask
  1009. return self.ctc_dkd_loss(stu_out, tea_out, targets)
  1010. else:
  1011. raise ValueError("error!!!!!!")
  1012. def forward_log(self, out1, out2):
  1013. if self.act is not None:
  1014. out1 = self.act(out1) + 1e-10
  1015. out2 = self.act(out2) + 1e-10
  1016. if self.use_log is True:
  1017. # for recognition distillation, log is needed for feature map
  1018. log_out1 = paddle.log(out1)
  1019. log_out2 = paddle.log(out2)
  1020. loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
  1021. return loss
  1022. class DistillCTCLogits(KLCTCLogits):
  1023. def __init__(
  1024. self, model_name_pairs=[], key=None, name="ctc_logits", reduction="mean"
  1025. ):
  1026. super().__init__(reduction=reduction)
  1027. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  1028. self.key = key
  1029. self.name = name
  1030. def _check_model_name_pairs(self, model_name_pairs):
  1031. if not isinstance(model_name_pairs, list):
  1032. return []
  1033. elif isinstance(model_name_pairs[0], list) and isinstance(
  1034. model_name_pairs[0][0], str
  1035. ):
  1036. return model_name_pairs
  1037. else:
  1038. return [model_name_pairs]
  1039. def forward(self, predicts, batch):
  1040. loss_dict = dict()
  1041. for idx, pair in enumerate(self.model_name_pairs):
  1042. out1 = predicts[pair[0]]
  1043. out2 = predicts[pair[1]]
  1044. if self.key is not None:
  1045. out1 = out1[self.key]["ctc"]
  1046. out2 = out2[self.key]["ctc"]
  1047. ctc_label = batch[1]
  1048. loss = super().forward(out1, out2, ctc_label)
  1049. if isinstance(loss, dict):
  1050. for key in loss:
  1051. loss_dict[
  1052. "{}_{}_{}".format(self.name, self.model_name_pairs, idx)
  1053. ] = loss[key]
  1054. else:
  1055. loss_dict["{}_{}".format(self.name, idx)] = loss
  1056. return loss_dict