rec_postprocess.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. import numpy as np
  2. # import paddle
  3. paddle = None
  4. # from paddle.nn import functional as F
  5. import re
  6. class BaseRecLabelDecode(object):
  7. """Convert between text-label and text-index"""
  8. def __init__(self, character_dict_path=None, use_space_char=False):
  9. self.beg_str = "sos"
  10. self.end_str = "eos"
  11. self.reverse = False
  12. self.character_str = []
  13. if character_dict_path is None:
  14. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  15. dict_character = list(self.character_str)
  16. else:
  17. with open(character_dict_path, "rb") as fin:
  18. lines = fin.readlines()
  19. for line in lines:
  20. line = line.decode("utf-8").strip("\n").strip("\r\n")
  21. self.character_str.append(line)
  22. if use_space_char:
  23. self.character_str.append(" ")
  24. dict_character = list(self.character_str)
  25. if "arabic" in character_dict_path:
  26. self.reverse = True
  27. dict_character = self.add_special_char(dict_character)
  28. self.dict = {}
  29. for i, char in enumerate(dict_character):
  30. self.dict[char] = i
  31. self.character = dict_character
  32. def pred_reverse(self, pred):
  33. pred_re = []
  34. c_current = ""
  35. for c in pred:
  36. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  37. if c_current != "":
  38. pred_re.append(c_current)
  39. pred_re.append(c)
  40. c_current = ""
  41. else:
  42. c_current += c
  43. if c_current != "":
  44. pred_re.append(c_current)
  45. return "".join(pred_re[::-1])
  46. def add_special_char(self, dict_character):
  47. return dict_character
  48. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  49. """convert text-index into text-label."""
  50. result_list = []
  51. ignored_tokens = self.get_ignored_tokens()
  52. batch_size = len(text_index)
  53. for batch_idx in range(batch_size):
  54. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  55. if is_remove_duplicate:
  56. selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
  57. for ignored_token in ignored_tokens:
  58. selection &= text_index[batch_idx] != ignored_token
  59. char_list = [
  60. self.character[text_id] for text_id in text_index[batch_idx][selection]
  61. ]
  62. if text_prob is not None:
  63. conf_list = text_prob[batch_idx][selection]
  64. else:
  65. conf_list = [1] * len(selection)
  66. if len(conf_list) == 0:
  67. conf_list = [0]
  68. text = "".join(char_list)
  69. if self.reverse: # for arabic rec
  70. text = self.pred_reverse(text)
  71. result_list.append((text, np.mean(conf_list).tolist()))
  72. return result_list
  73. def get_ignored_tokens(self):
  74. return [0] # for ctc blank
  75. class CTCLabelDecode(BaseRecLabelDecode):
  76. """Convert between text-label and text-index"""
  77. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  78. super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
  79. def __call__(self, preds, label=None, *args, **kwargs):
  80. if isinstance(preds, tuple) or isinstance(preds, list):
  81. preds = preds[-1]
  82. # if isinstance(preds, paddle.Tensor):
  83. # preds = preds.numpy()
  84. preds_idx = preds.argmax(axis=2)
  85. preds_prob = preds.max(axis=2)
  86. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  87. if label is None:
  88. return text
  89. label = self.decode(label)
  90. return text, label
  91. def add_special_char(self, dict_character):
  92. dict_character = ["blank"] + dict_character
  93. return dict_character
  94. class DistillationCTCLabelDecode(CTCLabelDecode):
  95. """
  96. Convert
  97. Convert between text-label and text-index
  98. """
  99. def __init__(
  100. self,
  101. character_dict_path=None,
  102. use_space_char=False,
  103. model_name=["student"],
  104. key=None,
  105. multi_head=False,
  106. **kwargs
  107. ):
  108. super(DistillationCTCLabelDecode, self).__init__(
  109. character_dict_path, use_space_char
  110. )
  111. if not isinstance(model_name, list):
  112. model_name = [model_name]
  113. self.model_name = model_name
  114. self.key = key
  115. self.multi_head = multi_head
  116. def __call__(self, preds, label=None, *args, **kwargs):
  117. output = dict()
  118. for name in self.model_name:
  119. pred = preds[name]
  120. if self.key is not None:
  121. pred = pred[self.key]
  122. if self.multi_head and isinstance(pred, dict):
  123. pred = pred["ctc"]
  124. output[name] = super().__call__(pred, label=label, *args, **kwargs)
  125. return output
  126. class AttnLabelDecode(BaseRecLabelDecode):
  127. """Convert between text-label and text-index"""
  128. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  129. super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)
  130. def add_special_char(self, dict_character):
  131. self.beg_str = "sos"
  132. self.end_str = "eos"
  133. dict_character = dict_character
  134. dict_character = [self.beg_str] + dict_character + [self.end_str]
  135. return dict_character
  136. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  137. """convert text-index into text-label."""
  138. result_list = []
  139. ignored_tokens = self.get_ignored_tokens()
  140. [beg_idx, end_idx] = self.get_ignored_tokens()
  141. batch_size = len(text_index)
  142. for batch_idx in range(batch_size):
  143. char_list = []
  144. conf_list = []
  145. for idx in range(len(text_index[batch_idx])):
  146. if text_index[batch_idx][idx] in ignored_tokens:
  147. continue
  148. if int(text_index[batch_idx][idx]) == int(end_idx):
  149. break
  150. if is_remove_duplicate:
  151. # only for predict
  152. if (
  153. idx > 0
  154. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  155. ):
  156. continue
  157. char_list.append(self.character[int(text_index[batch_idx][idx])])
  158. if text_prob is not None:
  159. conf_list.append(text_prob[batch_idx][idx])
  160. else:
  161. conf_list.append(1)
  162. text = "".join(char_list)
  163. result_list.append((text, np.mean(conf_list).tolist()))
  164. return result_list
  165. def __call__(self, preds, label=None, *args, **kwargs):
  166. """
  167. text = self.decode(text)
  168. if label is None:
  169. return text
  170. else:
  171. label = self.decode(label, is_remove_duplicate=False)
  172. return text, label
  173. """
  174. if isinstance(preds, paddle.Tensor):
  175. preds = preds.numpy()
  176. preds_idx = preds.argmax(axis=2)
  177. preds_prob = preds.max(axis=2)
  178. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  179. if label is None:
  180. return text
  181. label = self.decode(label, is_remove_duplicate=False)
  182. return text, label
  183. def get_ignored_tokens(self):
  184. beg_idx = self.get_beg_end_flag_idx("beg")
  185. end_idx = self.get_beg_end_flag_idx("end")
  186. return [beg_idx, end_idx]
  187. def get_beg_end_flag_idx(self, beg_or_end):
  188. if beg_or_end == "beg":
  189. idx = np.array(self.dict[self.beg_str])
  190. elif beg_or_end == "end":
  191. idx = np.array(self.dict[self.end_str])
  192. else:
  193. assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
  194. return idx
  195. class RFLLabelDecode(BaseRecLabelDecode):
  196. """Convert between text-label and text-index"""
  197. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  198. super(RFLLabelDecode, self).__init__(character_dict_path, use_space_char)
  199. def add_special_char(self, dict_character):
  200. self.beg_str = "sos"
  201. self.end_str = "eos"
  202. dict_character = dict_character
  203. dict_character = [self.beg_str] + dict_character + [self.end_str]
  204. return dict_character
  205. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  206. """convert text-index into text-label."""
  207. result_list = []
  208. ignored_tokens = self.get_ignored_tokens()
  209. [beg_idx, end_idx] = self.get_ignored_tokens()
  210. batch_size = len(text_index)
  211. for batch_idx in range(batch_size):
  212. char_list = []
  213. conf_list = []
  214. for idx in range(len(text_index[batch_idx])):
  215. if text_index[batch_idx][idx] in ignored_tokens:
  216. continue
  217. if int(text_index[batch_idx][idx]) == int(end_idx):
  218. break
  219. if is_remove_duplicate:
  220. # only for predict
  221. if (
  222. idx > 0
  223. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  224. ):
  225. continue
  226. char_list.append(self.character[int(text_index[batch_idx][idx])])
  227. if text_prob is not None:
  228. conf_list.append(text_prob[batch_idx][idx])
  229. else:
  230. conf_list.append(1)
  231. text = "".join(char_list)
  232. result_list.append((text, np.mean(conf_list).tolist()))
  233. return result_list
  234. def __call__(self, preds, label=None, *args, **kwargs):
  235. # if seq_outputs is not None:
  236. if isinstance(preds, tuple) or isinstance(preds, list):
  237. cnt_outputs, seq_outputs = preds
  238. if isinstance(seq_outputs, paddle.Tensor):
  239. seq_outputs = seq_outputs.numpy()
  240. preds_idx = seq_outputs.argmax(axis=2)
  241. preds_prob = seq_outputs.max(axis=2)
  242. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  243. if label is None:
  244. return text
  245. label = self.decode(label, is_remove_duplicate=False)
  246. return text, label
  247. else:
  248. cnt_outputs = preds
  249. if isinstance(cnt_outputs, paddle.Tensor):
  250. cnt_outputs = cnt_outputs.numpy()
  251. cnt_length = []
  252. for lens in cnt_outputs:
  253. length = round(np.sum(lens))
  254. cnt_length.append(length)
  255. if label is None:
  256. return cnt_length
  257. label = self.decode(label, is_remove_duplicate=False)
  258. length = [len(res[0]) for res in label]
  259. return cnt_length, length
  260. def get_ignored_tokens(self):
  261. beg_idx = self.get_beg_end_flag_idx("beg")
  262. end_idx = self.get_beg_end_flag_idx("end")
  263. return [beg_idx, end_idx]
  264. def get_beg_end_flag_idx(self, beg_or_end):
  265. if beg_or_end == "beg":
  266. idx = np.array(self.dict[self.beg_str])
  267. elif beg_or_end == "end":
  268. idx = np.array(self.dict[self.end_str])
  269. else:
  270. assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
  271. return idx
  272. class SEEDLabelDecode(BaseRecLabelDecode):
  273. """Convert between text-label and text-index"""
  274. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  275. super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)
  276. def add_special_char(self, dict_character):
  277. self.padding_str = "padding"
  278. self.end_str = "eos"
  279. self.unknown = "unknown"
  280. dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
  281. return dict_character
  282. def get_ignored_tokens(self):
  283. end_idx = self.get_beg_end_flag_idx("eos")
  284. return [end_idx]
  285. def get_beg_end_flag_idx(self, beg_or_end):
  286. if beg_or_end == "sos":
  287. idx = np.array(self.dict[self.beg_str])
  288. elif beg_or_end == "eos":
  289. idx = np.array(self.dict[self.end_str])
  290. else:
  291. assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
  292. return idx
  293. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  294. """convert text-index into text-label."""
  295. result_list = []
  296. [end_idx] = self.get_ignored_tokens()
  297. batch_size = len(text_index)
  298. for batch_idx in range(batch_size):
  299. char_list = []
  300. conf_list = []
  301. for idx in range(len(text_index[batch_idx])):
  302. if int(text_index[batch_idx][idx]) == int(end_idx):
  303. break
  304. if is_remove_duplicate:
  305. # only for predict
  306. if (
  307. idx > 0
  308. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  309. ):
  310. continue
  311. char_list.append(self.character[int(text_index[batch_idx][idx])])
  312. if text_prob is not None:
  313. conf_list.append(text_prob[batch_idx][idx])
  314. else:
  315. conf_list.append(1)
  316. text = "".join(char_list)
  317. result_list.append((text, np.mean(conf_list).tolist()))
  318. return result_list
  319. def __call__(self, preds, label=None, *args, **kwargs):
  320. """
  321. text = self.decode(text)
  322. if label is None:
  323. return text
  324. else:
  325. label = self.decode(label, is_remove_duplicate=False)
  326. return text, label
  327. """
  328. preds_idx = preds["rec_pred"]
  329. if isinstance(preds_idx, paddle.Tensor):
  330. preds_idx = preds_idx.numpy()
  331. if "rec_pred_scores" in preds:
  332. preds_idx = preds["rec_pred"]
  333. preds_prob = preds["rec_pred_scores"]
  334. else:
  335. preds_idx = preds["rec_pred"].argmax(axis=2)
  336. preds_prob = preds["rec_pred"].max(axis=2)
  337. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  338. if label is None:
  339. return text
  340. label = self.decode(label, is_remove_duplicate=False)
  341. return text, label
  342. class SRNLabelDecode(BaseRecLabelDecode):
  343. """Convert between text-label and text-index"""
  344. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  345. super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
  346. self.max_text_length = kwargs.get("max_text_length", 25)
  347. def __call__(self, preds, label=None, *args, **kwargs):
  348. pred = preds["predict"]
  349. char_num = len(self.character_str) + 2
  350. if isinstance(pred, paddle.Tensor):
  351. pred = pred.numpy()
  352. pred = np.reshape(pred, [-1, char_num])
  353. preds_idx = np.argmax(pred, axis=1)
  354. preds_prob = np.max(pred, axis=1)
  355. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  356. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  357. text = self.decode(preds_idx, preds_prob)
  358. if label is None:
  359. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  360. return text
  361. label = self.decode(label)
  362. return text, label
  363. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  364. """convert text-index into text-label."""
  365. result_list = []
  366. ignored_tokens = self.get_ignored_tokens()
  367. batch_size = len(text_index)
  368. for batch_idx in range(batch_size):
  369. char_list = []
  370. conf_list = []
  371. for idx in range(len(text_index[batch_idx])):
  372. if text_index[batch_idx][idx] in ignored_tokens:
  373. continue
  374. if is_remove_duplicate:
  375. # only for predict
  376. if (
  377. idx > 0
  378. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  379. ):
  380. continue
  381. char_list.append(self.character[int(text_index[batch_idx][idx])])
  382. if text_prob is not None:
  383. conf_list.append(text_prob[batch_idx][idx])
  384. else:
  385. conf_list.append(1)
  386. text = "".join(char_list)
  387. result_list.append((text, np.mean(conf_list).tolist()))
  388. return result_list
  389. def add_special_char(self, dict_character):
  390. dict_character = dict_character + [self.beg_str, self.end_str]
  391. return dict_character
  392. def get_ignored_tokens(self):
  393. beg_idx = self.get_beg_end_flag_idx("beg")
  394. end_idx = self.get_beg_end_flag_idx("end")
  395. return [beg_idx, end_idx]
  396. def get_beg_end_flag_idx(self, beg_or_end):
  397. if beg_or_end == "beg":
  398. idx = np.array(self.dict[self.beg_str])
  399. elif beg_or_end == "end":
  400. idx = np.array(self.dict[self.end_str])
  401. else:
  402. assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
  403. return idx
  404. class SARLabelDecode(BaseRecLabelDecode):
  405. """Convert between text-label and text-index"""
  406. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  407. super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)
  408. self.rm_symbol = kwargs.get("rm_symbol", False)
  409. def add_special_char(self, dict_character):
  410. beg_end_str = "<BOS/EOS>"
  411. unknown_str = "<UKN>"
  412. padding_str = "<PAD>"
  413. dict_character = dict_character + [unknown_str]
  414. self.unknown_idx = len(dict_character) - 1
  415. dict_character = dict_character + [beg_end_str]
  416. self.start_idx = len(dict_character) - 1
  417. self.end_idx = len(dict_character) - 1
  418. dict_character = dict_character + [padding_str]
  419. self.padding_idx = len(dict_character) - 1
  420. return dict_character
  421. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  422. """convert text-index into text-label."""
  423. result_list = []
  424. ignored_tokens = self.get_ignored_tokens()
  425. batch_size = len(text_index)
  426. for batch_idx in range(batch_size):
  427. char_list = []
  428. conf_list = []
  429. for idx in range(len(text_index[batch_idx])):
  430. if text_index[batch_idx][idx] in ignored_tokens:
  431. continue
  432. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  433. if text_prob is None and idx == 0:
  434. continue
  435. else:
  436. break
  437. if is_remove_duplicate:
  438. # only for predict
  439. if (
  440. idx > 0
  441. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  442. ):
  443. continue
  444. char_list.append(self.character[int(text_index[batch_idx][idx])])
  445. if text_prob is not None:
  446. conf_list.append(text_prob[batch_idx][idx])
  447. else:
  448. conf_list.append(1)
  449. text = "".join(char_list)
  450. if self.rm_symbol:
  451. comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
  452. text = text.lower()
  453. text = comp.sub("", text)
  454. result_list.append((text, np.mean(conf_list).tolist()))
  455. return result_list
  456. def __call__(self, preds, label=None, *args, **kwargs):
  457. if isinstance(preds, paddle.Tensor):
  458. preds = preds.numpy()
  459. preds_idx = preds.argmax(axis=2)
  460. preds_prob = preds.max(axis=2)
  461. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  462. if label is None:
  463. return text
  464. label = self.decode(label, is_remove_duplicate=False)
  465. return text, label
  466. def get_ignored_tokens(self):
  467. return [self.padding_idx]
  468. class DistillationSARLabelDecode(SARLabelDecode):
  469. """
  470. Convert
  471. Convert between text-label and text-index
  472. """
  473. def __init__(
  474. self,
  475. character_dict_path=None,
  476. use_space_char=False,
  477. model_name=["student"],
  478. key=None,
  479. multi_head=False,
  480. **kwargs
  481. ):
  482. super(DistillationSARLabelDecode, self).__init__(
  483. character_dict_path, use_space_char
  484. )
  485. if not isinstance(model_name, list):
  486. model_name = [model_name]
  487. self.model_name = model_name
  488. self.key = key
  489. self.multi_head = multi_head
  490. def __call__(self, preds, label=None, *args, **kwargs):
  491. output = dict()
  492. for name in self.model_name:
  493. pred = preds[name]
  494. if self.key is not None:
  495. pred = pred[self.key]
  496. if self.multi_head and isinstance(pred, dict):
  497. pred = pred["sar"]
  498. output[name] = super().__call__(pred, label=label, *args, **kwargs)
  499. return output
  500. class PRENLabelDecode(BaseRecLabelDecode):
  501. """Convert between text-label and text-index"""
  502. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  503. super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)
  504. def add_special_char(self, dict_character):
  505. padding_str = "<PAD>" # 0
  506. end_str = "<EOS>" # 1
  507. unknown_str = "<UNK>" # 2
  508. dict_character = [padding_str, end_str, unknown_str] + dict_character
  509. self.padding_idx = 0
  510. self.end_idx = 1
  511. self.unknown_idx = 2
  512. return dict_character
  513. def decode(self, text_index, text_prob=None):
  514. """convert text-index into text-label."""
  515. result_list = []
  516. batch_size = len(text_index)
  517. for batch_idx in range(batch_size):
  518. char_list = []
  519. conf_list = []
  520. for idx in range(len(text_index[batch_idx])):
  521. if text_index[batch_idx][idx] == self.end_idx:
  522. break
  523. if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
  524. continue
  525. char_list.append(self.character[int(text_index[batch_idx][idx])])
  526. if text_prob is not None:
  527. conf_list.append(text_prob[batch_idx][idx])
  528. else:
  529. conf_list.append(1)
  530. text = "".join(char_list)
  531. if len(text) > 0:
  532. result_list.append((text, np.mean(conf_list).tolist()))
  533. else:
  534. # here confidence of empty recog result is 1
  535. result_list.append(("", 1))
  536. return result_list
  537. def __call__(self, preds, label=None, *args, **kwargs):
  538. if isinstance(preds, paddle.Tensor):
  539. preds = preds.numpy()
  540. preds_idx = preds.argmax(axis=2)
  541. preds_prob = preds.max(axis=2)
  542. text = self.decode(preds_idx, preds_prob)
  543. if label is None:
  544. return text
  545. label = self.decode(label)
  546. return text, label
  547. class NRTRLabelDecode(BaseRecLabelDecode):
  548. """Convert between text-label and text-index"""
  549. def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
  550. super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)
  551. def __call__(self, preds, label=None, *args, **kwargs):
  552. if len(preds) == 2:
  553. preds_id = preds[0]
  554. preds_prob = preds[1]
  555. if isinstance(preds_id, paddle.Tensor):
  556. preds_id = preds_id.numpy()
  557. if isinstance(preds_prob, paddle.Tensor):
  558. preds_prob = preds_prob.numpy()
  559. if preds_id[0][0] == 2:
  560. preds_idx = preds_id[:, 1:]
  561. preds_prob = preds_prob[:, 1:]
  562. else:
  563. preds_idx = preds_id
  564. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  565. if label is None:
  566. return text
  567. label = self.decode(label[:, 1:])
  568. else:
  569. if isinstance(preds, paddle.Tensor):
  570. preds = preds.numpy()
  571. preds_idx = preds.argmax(axis=2)
  572. preds_prob = preds.max(axis=2)
  573. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  574. if label is None:
  575. return text
  576. label = self.decode(label[:, 1:])
  577. return text, label
  578. def add_special_char(self, dict_character):
  579. dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
  580. return dict_character
  581. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  582. """convert text-index into text-label."""
  583. result_list = []
  584. batch_size = len(text_index)
  585. for batch_idx in range(batch_size):
  586. char_list = []
  587. conf_list = []
  588. for idx in range(len(text_index[batch_idx])):
  589. try:
  590. char_idx = self.character[int(text_index[batch_idx][idx])]
  591. except:
  592. continue
  593. if char_idx == "</s>": # end
  594. break
  595. char_list.append(char_idx)
  596. if text_prob is not None:
  597. conf_list.append(text_prob[batch_idx][idx])
  598. else:
  599. conf_list.append(1)
  600. text = "".join(char_list)
  601. result_list.append((text.lower(), np.mean(conf_list).tolist()))
  602. return result_list
  603. class ViTSTRLabelDecode(NRTRLabelDecode):
  604. """Convert between text-label and text-index"""
  605. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  606. super(ViTSTRLabelDecode, self).__init__(character_dict_path, use_space_char)
  607. def __call__(self, preds, label=None, *args, **kwargs):
  608. if isinstance(preds, paddle.Tensor):
  609. preds = preds[:, 1:].numpy()
  610. else:
  611. preds = preds[:, 1:]
  612. preds_idx = preds.argmax(axis=2)
  613. preds_prob = preds.max(axis=2)
  614. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  615. if label is None:
  616. return text
  617. label = self.decode(label[:, 1:])
  618. return text, label
  619. def add_special_char(self, dict_character):
  620. dict_character = ["<s>", "</s>"] + dict_character
  621. return dict_character
  622. class ABINetLabelDecode(NRTRLabelDecode):
  623. """Convert between text-label and text-index"""
  624. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  625. super(ABINetLabelDecode, self).__init__(character_dict_path, use_space_char)
  626. def __call__(self, preds, label=None, *args, **kwargs):
  627. if isinstance(preds, dict):
  628. preds = preds["align"][-1].numpy()
  629. elif isinstance(preds, paddle.Tensor):
  630. preds = preds.numpy()
  631. else:
  632. preds = preds
  633. preds_idx = preds.argmax(axis=2)
  634. preds_prob = preds.max(axis=2)
  635. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  636. if label is None:
  637. return text
  638. label = self.decode(label)
  639. return text, label
  640. def add_special_char(self, dict_character):
  641. dict_character = ["</s>"] + dict_character
  642. return dict_character
  643. class SPINLabelDecode(AttnLabelDecode):
  644. """Convert between text-label and text-index"""
  645. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  646. super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char)
  647. def add_special_char(self, dict_character):
  648. self.beg_str = "sos"
  649. self.end_str = "eos"
  650. dict_character = dict_character
  651. dict_character = [self.beg_str] + [self.end_str] + dict_character
  652. return dict_character
  653. # class VLLabelDecode(BaseRecLabelDecode):
  654. # """ Convert between text-label and text-index """
  655. #
  656. # def __init__(self, character_dict_path=None, use_space_char=False,
  657. # **kwargs):
  658. # super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
  659. # self.max_text_length = kwargs.get('max_text_length', 25)
  660. # self.nclass = len(self.character) + 1
  661. #
  662. # def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  663. # """ convert text-index into text-label. """
  664. # result_list = []
  665. # ignored_tokens = self.get_ignored_tokens()
  666. # batch_size = len(text_index)
  667. # for batch_idx in range(batch_size):
  668. # selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  669. # if is_remove_duplicate:
  670. # selection[1:] = text_index[batch_idx][1:] != text_index[
  671. # batch_idx][:-1]
  672. # for ignored_token in ignored_tokens:
  673. # selection &= text_index[batch_idx] != ignored_token
  674. #
  675. # char_list = [
  676. # self.character[text_id - 1]
  677. # for text_id in text_index[batch_idx][selection]
  678. # ]
  679. # if text_prob is not None:
  680. # conf_list = text_prob[batch_idx][selection]
  681. # else:
  682. # conf_list = [1] * len(selection)
  683. # if len(conf_list) == 0:
  684. # conf_list = [0]
  685. #
  686. # text = ''.join(char_list)
  687. # result_list.append((text, np.mean(conf_list).tolist()))
  688. # return result_list
  689. #
  690. # def __call__(self, preds, label=None, length=None, *args, **kwargs):
  691. # if len(preds) == 2: # eval mode
  692. # text_pre, x = preds
  693. # b = text_pre.shape[1]
  694. # lenText = self.max_text_length
  695. # nsteps = self.max_text_length
  696. #
  697. # if not isinstance(text_pre, paddle.Tensor):
  698. # text_pre = paddle.to_tensor(text_pre, dtype='float32')
  699. #
  700. # out_res = paddle.zeros(
  701. # shape=[lenText, b, self.nclass], dtype=x.dtype)
  702. # out_length = paddle.zeros(shape=[b], dtype=x.dtype)
  703. # now_step = 0
  704. # for _ in range(nsteps):
  705. # if 0 in out_length and now_step < nsteps:
  706. # tmp_result = text_pre[now_step, :, :]
  707. # out_res[now_step] = tmp_result
  708. # tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
  709. # for j in range(b):
  710. # if out_length[j] == 0 and tmp_result[j] == 0:
  711. # out_length[j] = now_step + 1
  712. # now_step += 1
  713. # for j in range(0, b):
  714. # if int(out_length[j]) == 0:
  715. # out_length[j] = nsteps
  716. # start = 0
  717. # output = paddle.zeros(
  718. # shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
  719. # for i in range(0, b):
  720. # cur_length = int(out_length[i])
  721. # output[start:start + cur_length] = out_res[0:cur_length, i, :]
  722. # start += cur_length
  723. # net_out = output
  724. # length = out_length
  725. #
  726. # else: # train mode
  727. # net_out = preds[0]
  728. # length = length
  729. # net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
  730. # text = []
  731. # if not isinstance(net_out, paddle.Tensor):
  732. # net_out = paddle.to_tensor(net_out, dtype='float32')
  733. # net_out = F.softmax(net_out, axis=1)
  734. # for i in range(0, length.shape[0]):
  735. # preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
  736. # ) + length[i])].topk(1)[1][:, 0].tolist()
  737. # preds_text = ''.join([
  738. # self.character[idx - 1]
  739. # if idx > 0 and idx <= len(self.character) else ''
  740. # for idx in preds_idx
  741. # ])
  742. # preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
  743. # ) + length[i])].topk(1)[0][:, 0]
  744. # preds_prob = paddle.exp(
  745. # paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
  746. # text.append((preds_text, preds_prob.numpy()[0]))
  747. # if label is None:
  748. # return text
  749. # label = self.decode(label)
  750. # return text, label
  751. class CANLabelDecode(BaseRecLabelDecode):
  752. """Convert between latex-symbol and symbol-index"""
  753. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  754. super(CANLabelDecode, self).__init__(character_dict_path, use_space_char)
  755. def decode(self, text_index, preds_prob=None):
  756. result_list = []
  757. batch_size = len(text_index)
  758. for batch_idx in range(batch_size):
  759. seq_end = text_index[batch_idx].argmin(0)
  760. idx_list = text_index[batch_idx][:seq_end].tolist()
  761. symbol_list = [self.character[idx] for idx in idx_list]
  762. probs = []
  763. if preds_prob is not None:
  764. probs = preds_prob[batch_idx][: len(symbol_list)].tolist()
  765. result_list.append([" ".join(symbol_list), probs])
  766. return result_list
  767. def __call__(self, preds, label=None, *args, **kwargs):
  768. pred_prob, _, _, _ = preds
  769. preds_idx = pred_prob.argmax(axis=2)
  770. text = self.decode(preds_idx)
  771. if label is None:
  772. return text
  773. label = self.decode(label)
  774. return text, label