rec_postprocess.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import paddle
  17. from paddle.nn import functional as F
  18. import re
  19. import json
  20. class BaseRecLabelDecode(object):
  21. """Convert between text-label and text-index"""
  22. def __init__(self, character_dict_path=None, use_space_char=False):
  23. self.beg_str = "sos"
  24. self.end_str = "eos"
  25. self.reverse = False
  26. self.character_str = []
  27. if character_dict_path is None:
  28. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  29. dict_character = list(self.character_str)
  30. else:
  31. with open(character_dict_path, "rb") as fin:
  32. lines = fin.readlines()
  33. for line in lines:
  34. line = line.decode("utf-8").strip("\n").strip("\r\n")
  35. self.character_str.append(line)
  36. if use_space_char:
  37. self.character_str.append(" ")
  38. dict_character = list(self.character_str)
  39. if "arabic" in character_dict_path:
  40. self.reverse = True
  41. dict_character = self.add_special_char(dict_character)
  42. self.dict = {}
  43. for i, char in enumerate(dict_character):
  44. self.dict[char] = i
  45. self.character = dict_character
  46. def pred_reverse(self, pred):
  47. pred_re = []
  48. c_current = ""
  49. for c in pred:
  50. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  51. if c_current != "":
  52. pred_re.append(c_current)
  53. pred_re.append(c)
  54. c_current = ""
  55. else:
  56. c_current += c
  57. if c_current != "":
  58. pred_re.append(c_current)
  59. return "".join(pred_re[::-1])
  60. def add_special_char(self, dict_character):
  61. return dict_character
  62. def get_word_info(self, text, selection):
  63. """
  64. Group the decoded characters and record the corresponding decoded positions.
  65. Args:
  66. text: the decoded text
  67. selection: the bool array that identifies which columns of features are decoded as non-separated characters
  68. Returns:
  69. word_list: list of the grouped words
  70. word_col_list: list of decoding positions corresponding to each character in the grouped word
  71. state_list: list of marker to identify the type of grouping words, including two types of grouping words:
  72. - 'cn': continuous chinese characters (e.g., 你好啊)
  73. - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
  74. The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
  75. """
  76. state = None
  77. word_content = []
  78. word_col_content = []
  79. word_list = []
  80. word_col_list = []
  81. state_list = []
  82. valid_col = np.where(selection == True)[0]
  83. for c_i, char in enumerate(text):
  84. if "\u4e00" <= char <= "\u9fff":
  85. c_state = "cn"
  86. # Use \w with UNICODE flag to match letters (including accented chars like ä, ö, ü, é, etc.) and digits
  87. # Exclude underscore since \w includes it but we want to treat it as splitter
  88. elif bool(re.search(r"[\w]", char, re.UNICODE)) and char != "_":
  89. c_state = "en&num"
  90. else:
  91. c_state = "splitter"
  92. # Handle apostrophes in French words like "n'êtes"
  93. if char == "'" and state == "en&num":
  94. c_state = "en&num"
  95. if (
  96. char == "."
  97. and state == "en&num"
  98. and c_i + 1 < len(text)
  99. and bool(re.search("[0-9]", text[c_i + 1]))
  100. ): # grouping floating number
  101. c_state = "en&num"
  102. if (
  103. char == "-" and state == "en&num"
  104. ): # grouping word with '-', such as 'state-of-the-art'
  105. c_state = "en&num"
  106. if state == None:
  107. state = c_state
  108. if state != c_state:
  109. if len(word_content) != 0:
  110. word_list.append(word_content)
  111. word_col_list.append(word_col_content)
  112. state_list.append(state)
  113. word_content = []
  114. word_col_content = []
  115. state = c_state
  116. if state != "splitter":
  117. word_content.append(char)
  118. word_col_content.append(valid_col[c_i])
  119. if len(word_content) != 0:
  120. word_list.append(word_content)
  121. word_col_list.append(word_col_content)
  122. state_list.append(state)
  123. return word_list, word_col_list, state_list
  124. def decode(
  125. self,
  126. text_index,
  127. text_prob=None,
  128. is_remove_duplicate=False,
  129. return_word_box=False,
  130. ):
  131. """convert text-index into text-label."""
  132. result_list = []
  133. ignored_tokens = self.get_ignored_tokens()
  134. batch_size = len(text_index)
  135. for batch_idx in range(batch_size):
  136. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  137. if is_remove_duplicate:
  138. selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
  139. for ignored_token in ignored_tokens:
  140. selection &= text_index[batch_idx] != ignored_token
  141. char_list = [
  142. self.character[text_id] for text_id in text_index[batch_idx][selection]
  143. ]
  144. if text_prob is not None:
  145. conf_list = text_prob[batch_idx][selection]
  146. else:
  147. conf_list = [1] * len(selection)
  148. if len(conf_list) == 0:
  149. conf_list = [0]
  150. text = "".join(char_list)
  151. if self.reverse: # for arabic rec
  152. text = self.pred_reverse(text)
  153. if return_word_box:
  154. word_list, word_col_list, state_list = self.get_word_info(
  155. text, selection
  156. )
  157. result_list.append(
  158. (
  159. text,
  160. np.mean(conf_list).tolist(),
  161. [
  162. len(text_index[batch_idx]),
  163. word_list,
  164. word_col_list,
  165. state_list,
  166. ],
  167. )
  168. )
  169. else:
  170. result_list.append((text, np.mean(conf_list).tolist()))
  171. return result_list
  172. def get_ignored_tokens(self):
  173. return [0] # for ctc blank
  174. class CTCLabelDecode(BaseRecLabelDecode):
  175. """Convert between text-label and text-index"""
  176. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  177. super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
  178. def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
  179. if isinstance(preds, tuple) or isinstance(preds, list):
  180. preds = preds[-1]
  181. if isinstance(preds, paddle.Tensor):
  182. preds = preds.numpy()
  183. preds_idx = preds.argmax(axis=2)
  184. preds_prob = preds.max(axis=2)
  185. text = self.decode(
  186. preds_idx,
  187. preds_prob,
  188. is_remove_duplicate=True,
  189. return_word_box=return_word_box,
  190. )
  191. if return_word_box:
  192. for rec_idx, rec in enumerate(text):
  193. wh_ratio = kwargs["wh_ratio_list"][rec_idx]
  194. max_wh_ratio = kwargs["max_wh_ratio"]
  195. rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
  196. if label is None:
  197. return text
  198. label = self.decode(label)
  199. return text, label
  200. def add_special_char(self, dict_character):
  201. dict_character = ["blank"] + dict_character
  202. return dict_character
  203. class DistillationCTCLabelDecode(CTCLabelDecode):
  204. """
  205. Convert
  206. Convert between text-label and text-index
  207. """
  208. def __init__(
  209. self,
  210. character_dict_path=None,
  211. use_space_char=False,
  212. model_name=["student"],
  213. key=None,
  214. multi_head=False,
  215. **kwargs,
  216. ):
  217. super(DistillationCTCLabelDecode, self).__init__(
  218. character_dict_path, use_space_char
  219. )
  220. if not isinstance(model_name, list):
  221. model_name = [model_name]
  222. self.model_name = model_name
  223. self.key = key
  224. self.multi_head = multi_head
  225. def __call__(self, preds, label=None, *args, **kwargs):
  226. output = dict()
  227. for name in self.model_name:
  228. pred = preds[name]
  229. if self.key is not None:
  230. pred = pred[self.key]
  231. if self.multi_head and isinstance(pred, dict):
  232. pred = pred["ctc"]
  233. output[name] = super().__call__(pred, label=label, *args, **kwargs)
  234. return output
  235. class AttnLabelDecode(BaseRecLabelDecode):
  236. """Convert between text-label and text-index"""
  237. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  238. super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)
  239. def add_special_char(self, dict_character):
  240. self.beg_str = "sos"
  241. self.end_str = "eos"
  242. dict_character = dict_character
  243. dict_character = [self.beg_str] + dict_character + [self.end_str]
  244. return dict_character
  245. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  246. """convert text-index into text-label."""
  247. result_list = []
  248. ignored_tokens = self.get_ignored_tokens()
  249. [beg_idx, end_idx] = self.get_ignored_tokens()
  250. batch_size = len(text_index)
  251. for batch_idx in range(batch_size):
  252. char_list = []
  253. conf_list = []
  254. for idx in range(len(text_index[batch_idx])):
  255. if text_index[batch_idx][idx] in ignored_tokens:
  256. continue
  257. if int(text_index[batch_idx][idx]) == int(end_idx):
  258. break
  259. if is_remove_duplicate:
  260. # only for predict
  261. if (
  262. idx > 0
  263. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  264. ):
  265. continue
  266. char_list.append(self.character[int(text_index[batch_idx][idx])])
  267. if text_prob is not None:
  268. conf_list.append(text_prob[batch_idx][idx])
  269. else:
  270. conf_list.append(1)
  271. text = "".join(char_list)
  272. result_list.append((text, np.mean(conf_list).tolist()))
  273. return result_list
  274. def __call__(self, preds, label=None, *args, **kwargs):
  275. """
  276. text = self.decode(text)
  277. if label is None:
  278. return text
  279. else:
  280. label = self.decode(label, is_remove_duplicate=False)
  281. return text, label
  282. """
  283. if isinstance(preds, paddle.Tensor):
  284. preds = preds.numpy()
  285. preds_idx = preds.argmax(axis=2)
  286. preds_prob = preds.max(axis=2)
  287. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  288. if label is None:
  289. return text
  290. label = self.decode(label, is_remove_duplicate=False)
  291. return text, label
  292. def get_ignored_tokens(self):
  293. beg_idx = self.get_beg_end_flag_idx("beg")
  294. end_idx = self.get_beg_end_flag_idx("end")
  295. return [beg_idx, end_idx]
  296. def get_beg_end_flag_idx(self, beg_or_end):
  297. if beg_or_end == "beg":
  298. idx = np.array(self.dict[self.beg_str])
  299. elif beg_or_end == "end":
  300. idx = np.array(self.dict[self.end_str])
  301. else:
  302. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  303. return idx
  304. class RFLLabelDecode(BaseRecLabelDecode):
  305. """Convert between text-label and text-index"""
  306. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  307. super(RFLLabelDecode, self).__init__(character_dict_path, use_space_char)
  308. def add_special_char(self, dict_character):
  309. self.beg_str = "sos"
  310. self.end_str = "eos"
  311. dict_character = dict_character
  312. dict_character = [self.beg_str] + dict_character + [self.end_str]
  313. return dict_character
  314. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  315. """convert text-index into text-label."""
  316. result_list = []
  317. ignored_tokens = self.get_ignored_tokens()
  318. [beg_idx, end_idx] = self.get_ignored_tokens()
  319. batch_size = len(text_index)
  320. for batch_idx in range(batch_size):
  321. char_list = []
  322. conf_list = []
  323. for idx in range(len(text_index[batch_idx])):
  324. if text_index[batch_idx][idx] in ignored_tokens:
  325. continue
  326. if int(text_index[batch_idx][idx]) == int(end_idx):
  327. break
  328. if is_remove_duplicate:
  329. # only for predict
  330. if (
  331. idx > 0
  332. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  333. ):
  334. continue
  335. char_list.append(self.character[int(text_index[batch_idx][idx])])
  336. if text_prob is not None:
  337. conf_list.append(text_prob[batch_idx][idx])
  338. else:
  339. conf_list.append(1)
  340. text = "".join(char_list)
  341. result_list.append((text, np.mean(conf_list).tolist()))
  342. return result_list
  343. def __call__(self, preds, label=None, *args, **kwargs):
  344. # if seq_outputs is not None:
  345. if isinstance(preds, tuple) or isinstance(preds, list):
  346. cnt_outputs, seq_outputs = preds
  347. if isinstance(seq_outputs, paddle.Tensor):
  348. seq_outputs = seq_outputs.numpy()
  349. preds_idx = seq_outputs.argmax(axis=2)
  350. preds_prob = seq_outputs.max(axis=2)
  351. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  352. if label is None:
  353. return text
  354. label = self.decode(label, is_remove_duplicate=False)
  355. return text, label
  356. else:
  357. cnt_outputs = preds
  358. if isinstance(cnt_outputs, paddle.Tensor):
  359. cnt_outputs = cnt_outputs.numpy()
  360. cnt_length = []
  361. for lens in cnt_outputs:
  362. length = round(np.sum(lens))
  363. cnt_length.append(length)
  364. if label is None:
  365. return cnt_length
  366. label = self.decode(label, is_remove_duplicate=False)
  367. length = [len(res[0]) for res in label]
  368. return cnt_length, length
  369. def get_ignored_tokens(self):
  370. beg_idx = self.get_beg_end_flag_idx("beg")
  371. end_idx = self.get_beg_end_flag_idx("end")
  372. return [beg_idx, end_idx]
  373. def get_beg_end_flag_idx(self, beg_or_end):
  374. if beg_or_end == "beg":
  375. idx = np.array(self.dict[self.beg_str])
  376. elif beg_or_end == "end":
  377. idx = np.array(self.dict[self.end_str])
  378. else:
  379. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  380. return idx
  381. class SEEDLabelDecode(BaseRecLabelDecode):
  382. """Convert between text-label and text-index"""
  383. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  384. super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)
  385. def add_special_char(self, dict_character):
  386. self.padding_str = "padding"
  387. self.end_str = "eos"
  388. self.unknown = "unknown"
  389. dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
  390. return dict_character
  391. def get_ignored_tokens(self):
  392. end_idx = self.get_beg_end_flag_idx("eos")
  393. return [end_idx]
  394. def get_beg_end_flag_idx(self, beg_or_end):
  395. if beg_or_end == "sos":
  396. idx = np.array(self.dict[self.beg_str])
  397. elif beg_or_end == "eos":
  398. idx = np.array(self.dict[self.end_str])
  399. else:
  400. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  401. return idx
  402. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  403. """convert text-index into text-label."""
  404. result_list = []
  405. [end_idx] = self.get_ignored_tokens()
  406. batch_size = len(text_index)
  407. for batch_idx in range(batch_size):
  408. char_list = []
  409. conf_list = []
  410. for idx in range(len(text_index[batch_idx])):
  411. if int(text_index[batch_idx][idx]) == int(end_idx):
  412. break
  413. if is_remove_duplicate:
  414. # only for predict
  415. if (
  416. idx > 0
  417. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  418. ):
  419. continue
  420. char_list.append(self.character[int(text_index[batch_idx][idx])])
  421. if text_prob is not None:
  422. conf_list.append(text_prob[batch_idx][idx])
  423. else:
  424. conf_list.append(1)
  425. text = "".join(char_list)
  426. result_list.append((text, np.mean(conf_list).tolist()))
  427. return result_list
  428. def __call__(self, preds, label=None, *args, **kwargs):
  429. """
  430. text = self.decode(text)
  431. if label is None:
  432. return text
  433. else:
  434. label = self.decode(label, is_remove_duplicate=False)
  435. return text, label
  436. """
  437. preds_idx = preds["rec_pred"]
  438. if isinstance(preds_idx, paddle.Tensor):
  439. preds_idx = preds_idx.numpy()
  440. if "rec_pred_scores" in preds:
  441. preds_idx = preds["rec_pred"]
  442. preds_prob = preds["rec_pred_scores"]
  443. else:
  444. preds_idx = preds["rec_pred"].argmax(axis=2)
  445. preds_prob = preds["rec_pred"].max(axis=2)
  446. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  447. if label is None:
  448. return text
  449. label = self.decode(label, is_remove_duplicate=False)
  450. return text, label
  451. class SRNLabelDecode(BaseRecLabelDecode):
  452. """Convert between text-label and text-index"""
  453. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  454. super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
  455. self.max_text_length = kwargs.get("max_text_length", 25)
  456. def __call__(self, preds, label=None, *args, **kwargs):
  457. pred = preds["predict"]
  458. char_num = len(self.character_str) + 2
  459. if isinstance(pred, paddle.Tensor):
  460. pred = pred.numpy()
  461. pred = np.reshape(pred, [-1, char_num])
  462. preds_idx = np.argmax(pred, axis=1)
  463. preds_prob = np.max(pred, axis=1)
  464. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  465. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  466. text = self.decode(preds_idx, preds_prob)
  467. if label is None:
  468. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  469. return text
  470. label = self.decode(label)
  471. return text, label
  472. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  473. """convert text-index into text-label."""
  474. result_list = []
  475. ignored_tokens = self.get_ignored_tokens()
  476. batch_size = len(text_index)
  477. for batch_idx in range(batch_size):
  478. char_list = []
  479. conf_list = []
  480. for idx in range(len(text_index[batch_idx])):
  481. if text_index[batch_idx][idx] in ignored_tokens:
  482. continue
  483. if is_remove_duplicate:
  484. # only for predict
  485. if (
  486. idx > 0
  487. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  488. ):
  489. continue
  490. char_list.append(self.character[int(text_index[batch_idx][idx])])
  491. if text_prob is not None:
  492. conf_list.append(text_prob[batch_idx][idx])
  493. else:
  494. conf_list.append(1)
  495. text = "".join(char_list)
  496. result_list.append((text, np.mean(conf_list).tolist()))
  497. return result_list
  498. def add_special_char(self, dict_character):
  499. dict_character = dict_character + [self.beg_str, self.end_str]
  500. return dict_character
  501. def get_ignored_tokens(self):
  502. beg_idx = self.get_beg_end_flag_idx("beg")
  503. end_idx = self.get_beg_end_flag_idx("end")
  504. return [beg_idx, end_idx]
  505. def get_beg_end_flag_idx(self, beg_or_end):
  506. if beg_or_end == "beg":
  507. idx = np.array(self.dict[self.beg_str])
  508. elif beg_or_end == "end":
  509. idx = np.array(self.dict[self.end_str])
  510. else:
  511. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  512. return idx
  513. class ParseQLabelDecode(BaseRecLabelDecode):
  514. """Convert between text-label and text-index"""
  515. BOS = "[B]"
  516. EOS = "[E]"
  517. PAD = "[P]"
  518. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  519. super(ParseQLabelDecode, self).__init__(character_dict_path, use_space_char)
  520. self.max_text_length = kwargs.get("max_text_length", 25)
  521. def __call__(self, preds, label=None, *args, **kwargs):
  522. if isinstance(preds, dict):
  523. pred = preds["predict"]
  524. else:
  525. pred = preds
  526. char_num = (
  527. len(self.character_str) + 1
  528. ) # We don't predict <bos> nor <pad>, with only addition <eos>
  529. if isinstance(pred, paddle.Tensor):
  530. pred = pred.numpy()
  531. B, L = pred.shape[:2]
  532. pred = np.reshape(pred, [-1, char_num])
  533. preds_idx = np.argmax(pred, axis=1)
  534. preds_prob = np.max(pred, axis=1)
  535. preds_idx = np.reshape(preds_idx, [B, L])
  536. preds_prob = np.reshape(preds_prob, [B, L])
  537. if label is None:
  538. text = self.decode(preds_idx, preds_prob, raw=False)
  539. return text
  540. text = self.decode(preds_idx, preds_prob, raw=False)
  541. label = self.decode(label, None, False)
  542. return text, label
  543. def decode(self, text_index, text_prob=None, raw=False):
  544. """convert text-index into text-label."""
  545. result_list = []
  546. ignored_tokens = self.get_ignored_tokens()
  547. batch_size = len(text_index)
  548. for batch_idx in range(batch_size):
  549. char_list = []
  550. conf_list = []
  551. index = text_index[batch_idx, :]
  552. prob = None
  553. if text_prob is not None:
  554. prob = text_prob[batch_idx, :]
  555. if not raw:
  556. index, prob = self._filter(index, prob)
  557. for idx in range(len(index)):
  558. if index[idx] in ignored_tokens:
  559. continue
  560. char_list.append(self.character[int(index[idx])])
  561. if text_prob is not None:
  562. conf_list.append(prob[idx])
  563. else:
  564. conf_list.append(1)
  565. text = "".join(char_list)
  566. result_list.append((text, np.mean(conf_list).tolist()))
  567. return result_list
  568. def add_special_char(self, dict_character):
  569. dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
  570. return dict_character
  571. def _filter(self, ids, probs=None):
  572. ids = ids.tolist()
  573. try:
  574. eos_idx = ids.index(self.dict[self.EOS])
  575. except ValueError:
  576. eos_idx = len(ids) # Nothing to truncate.
  577. # Truncate after EOS
  578. ids = ids[:eos_idx]
  579. if probs is not None:
  580. probs = probs[: eos_idx + 1] # but include prob. for EOS (if it exists)
  581. return ids, probs
  582. def get_ignored_tokens(self):
  583. return [self.dict[self.BOS], self.dict[self.EOS], self.dict[self.PAD]]
  584. class SARLabelDecode(BaseRecLabelDecode):
  585. """Convert between text-label and text-index"""
  586. def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
  587. super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)
  588. self.rm_symbol = kwargs.get("rm_symbol", False)
  589. def add_special_char(self, dict_character):
  590. beg_end_str = "<BOS/EOS>"
  591. unknown_str = "<UKN>"
  592. padding_str = "<PAD>"
  593. dict_character = dict_character + [unknown_str]
  594. self.unknown_idx = len(dict_character) - 1
  595. dict_character = dict_character + [beg_end_str]
  596. self.start_idx = len(dict_character) - 1
  597. self.end_idx = len(dict_character) - 1
  598. dict_character = dict_character + [padding_str]
  599. self.padding_idx = len(dict_character) - 1
  600. return dict_character
  601. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  602. """convert text-index into text-label."""
  603. result_list = []
  604. ignored_tokens = self.get_ignored_tokens()
  605. batch_size = len(text_index)
  606. for batch_idx in range(batch_size):
  607. char_list = []
  608. conf_list = []
  609. for idx in range(len(text_index[batch_idx])):
  610. if text_index[batch_idx][idx] in ignored_tokens:
  611. continue
  612. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  613. if text_prob is None and idx == 0:
  614. continue
  615. else:
  616. break
  617. if is_remove_duplicate:
  618. # only for predict
  619. if (
  620. idx > 0
  621. and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
  622. ):
  623. continue
  624. char_list.append(self.character[int(text_index[batch_idx][idx])])
  625. if text_prob is not None:
  626. conf_list.append(text_prob[batch_idx][idx])
  627. else:
  628. conf_list.append(1)
  629. text = "".join(char_list)
  630. if self.rm_symbol:
  631. comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
  632. text = text.lower()
  633. text = comp.sub("", text)
  634. result_list.append((text, np.mean(conf_list).tolist()))
  635. return result_list
  636. def __call__(self, preds, label=None, *args, **kwargs):
  637. if isinstance(preds, paddle.Tensor):
  638. preds = preds.numpy()
  639. preds_idx = preds.argmax(axis=2)
  640. preds_prob = preds.max(axis=2)
  641. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  642. if label is None:
  643. return text
  644. label = self.decode(label, is_remove_duplicate=False)
  645. return text, label
  646. def get_ignored_tokens(self):
  647. return [self.padding_idx]
  class SATRNLabelDecode(BaseRecLabelDecode):
      """Convert between text-label and text-index for SATRN.

      ``add_special_char`` appends ``<UKN>``, one ``<BOS/EOS>`` token that is
      used as both the start and the end token, and ``<PAD>`` to the
      character list.
      """

      # Used when ``rm_symbol`` is set: drop everything except ASCII letters,
      # digits and CJK ideographs (U+4E00..U+9FA5). The previous pattern
      # "[^A-Z^a-z^0-9^...]" also kept literal "^" characters, because every
      # "^" after the first one inside a character class is an ordinary caret.
      # Compiled once instead of once per decoded sample.
      _SYMBOL_PATTERN = re.compile("[^A-Za-z0-9\u4e00-\u9fa5]")

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(SATRNLabelDecode, self).__init__(character_dict_path, use_space_char)
          # When True, decoded text is lower-cased and stripped of symbols.
          self.rm_symbol = kwargs.get("rm_symbol", False)

      def add_special_char(self, dict_character):
          """Append ``<UKN>``, ``<BOS/EOS>`` and ``<PAD>`` and record their indices.

          ``start_idx`` and ``end_idx`` are the same index because start and
          end share the single ``<BOS/EOS>`` token.
          """
          beg_end_str = "<BOS/EOS>"
          unknown_str = "<UKN>"
          padding_str = "<PAD>"
          dict_character = dict_character + [unknown_str]
          self.unknown_idx = len(dict_character) - 1
          dict_character = dict_character + [beg_end_str]
          self.start_idx = len(dict_character) - 1
          self.end_idx = len(dict_character) - 1
          dict_character = dict_character + [padding_str]
          self.padding_idx = len(dict_character) - 1
          return dict_character

      def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
          """Convert batches of token indices into ``(text, confidence)`` tuples.

          Args:
              text_index: batch of index sequences.
              text_prob: optional batch of per-step probabilities aligned with
                  ``text_index``; ``None`` when decoding labels.
              is_remove_duplicate: skip a token equal to its predecessor.

          Returns:
              list of ``(text, mean_confidence)`` tuples, one per sample.
          """
          result_list = []
          ignored_tokens = self.get_ignored_tokens()
          for batch_idx in range(len(text_index)):
              sequence = text_index[batch_idx]
              char_list = []
              conf_list = []
              for idx in range(len(sequence)):
                  if sequence[idx] in ignored_tokens:
                      continue
                  if int(sequence[idx]) == int(self.end_idx):
                      # Without probabilities (label decoding) a leading
                      # BOS/EOS is skipped instead of ending the sequence;
                      # presumably labels begin with that token — confirm.
                      if text_prob is None and idx == 0:
                          continue
                      break
                  if (
                      is_remove_duplicate
                      and idx > 0
                      and sequence[idx - 1] == sequence[idx]
                  ):
                      # only for predict
                      continue
                  char_list.append(self.character[int(sequence[idx])])
                  if text_prob is not None:
                      conf_list.append(text_prob[batch_idx][idx])
                  else:
                      conf_list.append(1)
              text = "".join(char_list)
              if self.rm_symbol:
                  text = self._SYMBOL_PATTERN.sub("", text.lower())
              result_list.append((text, np.mean(conf_list).tolist()))
          return result_list

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode logits (argmax/max over axis 2) and, if given, the labels."""
          if isinstance(preds, paddle.Tensor):
              preds = preds.numpy()
          preds_idx = preds.argmax(axis=2)
          preds_prob = preds.max(axis=2)
          text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
          if label is None:
              return text
          label = self.decode(label, is_remove_duplicate=False)
          return text, label

      def get_ignored_tokens(self):
          """Only the padding index is skipped outright while decoding."""
          return [self.padding_idx]
  class DistillationSARLabelDecode(SARLabelDecode):
      """Convert between text-label and text-index for distillation models.

      Runs ``SARLabelDecode`` on the prediction of every sub-model named in
      ``model_name`` and returns the results keyed by that name.
      """

      def __init__(
          self,
          character_dict_path=None,
          use_space_char=False,
          model_name=None,
          key=None,
          multi_head=False,
          **kwargs,
      ):
          """
          Args:
              character_dict_path: forwarded to ``SARLabelDecode``.
              use_space_char: forwarded to ``SARLabelDecode``.
              model_name: sub-model name or list/tuple of names to decode;
                  defaults to ``["student"]``.
              key: optional key selecting an entry inside each sub-model output.
              multi_head: when True and the selected output is a dict, decode
                  its ``"sar"`` entry.
          """
          super(DistillationSARLabelDecode, self).__init__(
              character_dict_path, use_space_char
          )
          if model_name is None:
              model_name = ["student"]
          elif not isinstance(model_name, (list, tuple)):
              model_name = [model_name]
          # Copy so no two instances (or the caller) share one mutable list.
          self.model_name = list(model_name)
          self.key = key
          self.multi_head = multi_head

      def __call__(self, preds, label=None, *args, **kwargs):
          """Return ``{model_name: decoded_result}`` for every configured model."""
          output = dict()
          for name in self.model_name:
              pred = preds[name]
              if self.key is not None:
                  pred = pred[self.key]
              if self.multi_head and isinstance(pred, dict):
                  pred = pred["sar"]
              # Pass ``label`` positionally: "label=label, *args" raises
              # TypeError (multiple values for 'label') whenever args is non-empty.
              output[name] = super().__call__(pred, label, *args, **kwargs)
          return output
  class PRENLabelDecode(BaseRecLabelDecode):
      """Convert between text-label and text-index for PREN.

      Special tokens are prepended to the character list:
      ``<PAD>`` = 0, ``<EOS>`` = 1, ``<UNK>`` = 2.
      """

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)

      def add_special_char(self, dict_character):
          """Prepend the padding/end/unknown tokens and record their indices."""
          self.padding_idx, self.end_idx, self.unknown_idx = 0, 1, 2
          specials = ["<PAD>", "<EOS>", "<UNK>"]
          return specials + dict_character

      def decode(self, text_index, text_prob=None):
          """Turn index sequences into ``(text, confidence)`` tuples.

          Decoding of a row stops at ``<EOS>``; ``<PAD>`` and ``<UNK>`` are
          skipped. An empty result is reported with confidence 1.
          """
          skipped = (self.padding_idx, self.unknown_idx)
          results = []
          for row_no in range(len(text_index)):
              row = text_index[row_no]
              chars, confs = [], []
              for pos in range(len(row)):
                  token = row[pos]
                  if token == self.end_idx:
                      break
                  if token in skipped:
                      continue
                  chars.append(self.character[int(token)])
                  confs.append(1 if text_prob is None else text_prob[row_no][pos])
              joined = "".join(chars)
              if not joined:
                  # here confidence of empty recog result is 1
                  results.append(("", 1))
              else:
                  results.append((joined, np.mean(confs).tolist()))
          return results

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode logits (argmax/max over axis 2) and, if given, the labels."""
          if isinstance(preds, paddle.Tensor):
              preds = preds.numpy()
          decoded = self.decode(preds.argmax(axis=2), preds.max(axis=2))
          if label is not None:
              return decoded, self.decode(label)
          return decoded
  class NRTRLabelDecode(BaseRecLabelDecode):
      """Convert between text-label and text-index for NRTR.

      The character list starts with ``blank``, ``<unk>``, ``<s>`` and
      ``</s>`` (indices 0-3); decoding of a row stops at ``</s>``.
      """

      def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
          super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode predictions and, if given, labels.

          ``preds`` is either an ``(ids, probs)`` list/tuple or a logits
          array/tensor that is reduced with argmax/max over axis 2. Labels are
          decoded after dropping their first position.
          """
          # Only a list/tuple is an (ids, probs) pair: ``len(preds) == 2`` alone
          # also matched a logits tensor/array whose batch size is 2.
          if isinstance(preds, (list, tuple)) and len(preds) == 2:
              preds_id, preds_prob = preds
              if isinstance(preds_id, paddle.Tensor):
                  preds_id = preds_id.numpy()
              if isinstance(preds_prob, paddle.Tensor):
                  preds_prob = preds_prob.numpy()
              # Drop a leading "<s>" (index 2); only sample 0 is inspected and
              # the decision is applied to the whole batch.
              if preds_id[0][0] == 2:
                  preds_idx = preds_id[:, 1:]
                  preds_prob = preds_prob[:, 1:]
              else:
                  preds_idx = preds_id
          else:
              if isinstance(preds, paddle.Tensor):
                  preds = preds.numpy()
              preds_idx = preds.argmax(axis=2)
              preds_prob = preds.max(axis=2)
          text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
          if label is None:
              return text
          label = self.decode(label[:, 1:])
          return text, label

      def add_special_char(self, dict_character):
          """Prepend ``blank``, ``<unk>``, ``<s>`` and ``</s>``."""
          dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
          return dict_character

      def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
          """Convert index sequences into ``(text, mean_confidence)`` tuples.

          Indices that do not map to a character are skipped. Decoding of a
          row stops at ``</s>``. ``is_remove_duplicate`` is accepted for
          interface compatibility and not used.
          """
          result_list = []
          batch_size = len(text_index)
          for batch_idx in range(batch_size):
              char_list = []
              conf_list = []
              for idx in range(len(text_index[batch_idx])):
                  try:
                      char = self.character[int(text_index[batch_idx][idx])]
                  except (LookupError, TypeError, ValueError):
                      # Out-of-range or non-integer index: skip it. A bare
                      # ``except:`` also swallowed KeyboardInterrupt/SystemExit.
                      continue
                  if char == "</s>":  # end
                      break
                  char_list.append(char)
                  if text_prob is not None:
                      conf_list.append(text_prob[batch_idx][idx])
                  else:
                      conf_list.append(1)
              text = "".join(char_list)
              result_list.append((text, np.mean(conf_list).tolist()))
          return result_list
  class ViTSTRLabelDecode(NRTRLabelDecode):
      """Convert between text-label and text-index for ViTSTR.

      The character list starts with ``<s>`` and ``</s>``. The first position
      of both predictions and labels is dropped before decoding.
      """

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(ViTSTRLabelDecode, self).__init__(character_dict_path, use_space_char)

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode logits (argmax/max over axis 2) and, if given, the labels."""
          trimmed = preds[:, 1:]
          if isinstance(preds, paddle.Tensor):
              trimmed = trimmed.numpy()
          best_idx = trimmed.argmax(axis=2)
          best_prob = trimmed.max(axis=2)
          decoded = self.decode(best_idx, best_prob, is_remove_duplicate=False)
          if label is not None:
              return decoded, self.decode(label[:, 1:])
          return decoded

      def add_special_char(self, dict_character):
          """Prepend the ``<s>`` and ``</s>`` tokens."""
          return ["<s>", "</s>"] + dict_character
  class ABINetLabelDecode(NRTRLabelDecode):
      """Convert between text-label and text-index for ABINet.

      The character list starts with ``</s>``.
      """

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(ABINetLabelDecode, self).__init__(character_dict_path, use_space_char)

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode logits and, if given, the labels.

          A dict prediction is decoded from the last entry of its ``"align"``
          output.
          """
          if isinstance(preds, dict):
              logits = preds["align"][-1].numpy()
          elif isinstance(preds, paddle.Tensor):
              logits = preds.numpy()
          else:
              logits = preds
          decoded = self.decode(
              logits.argmax(axis=2), logits.max(axis=2), is_remove_duplicate=False
          )
          if label is not None:
              return decoded, self.decode(label)
          return decoded

      def add_special_char(self, dict_character):
          """Prepend the ``</s>`` token."""
          return ["</s>"] + dict_character
  class SPINLabelDecode(AttnLabelDecode):
      """Convert between text-label and text-index for SPIN."""

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char)

      def add_special_char(self, dict_character):
          """Prepend ``sos``/``eos`` and remember them as ``beg_str``/``end_str``."""
          self.beg_str, self.end_str = "sos", "eos"
          return [self.beg_str, self.end_str] + dict_character
  class VLLabelDecode(BaseRecLabelDecode):
      """Convert between text-label and text-index.

      Ids are offset by one relative to ``self.character``: ``__call__`` maps
      an id ``k`` in ``1..len(self.character)`` to ``self.character[k - 1]``
      and emits nothing for any other id, and ``decode`` applies the same
      ``k - 1`` offset. In eval mode a top-1 id of 0 ends a sample's sequence.
      """

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
          # Upper bound on decoding steps in eval mode (default 25).
          self.max_text_length = kwargs.get("max_text_length", 25)
          # One more class than characters, for id 0.
          self.nclass = len(self.character) + 1

      def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
          """convert text-index into text-label.

          Args:
              text_index: batch of index rows; each row is compared
                  element-wise and boolean-masked, so numpy arrays are
                  presumably expected — TODO confirm.
              text_prob: optional per-position probabilities aligned with
                  ``text_index``.
              is_remove_duplicate: drop a token equal to its predecessor.

          Returns:
              list of ``(text, mean_confidence)`` tuples, one per row.
          """
          result_list = []
          ignored_tokens = self.get_ignored_tokens()
          batch_size = len(text_index)
          for batch_idx in range(batch_size):
              # Boolean mask of the positions that are kept.
              selection = np.ones(len(text_index[batch_idx]), dtype=bool)
              if is_remove_duplicate:
                  selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
              for ignored_token in ignored_tokens:
                  selection &= text_index[batch_idx] != ignored_token
              char_list = [
                  self.character[text_id - 1]
                  for text_id in text_index[batch_idx][selection]
              ]
              if text_prob is not None:
                  conf_list = text_prob[batch_idx][selection]
              else:
                  # NOTE(review): sized by the whole row rather than the kept
                  # positions, so a row whose positions are all filtered out
                  # still scores 1 here instead of the 0 used below — confirm.
                  conf_list = [1] * len(selection)
              if len(conf_list) == 0:
                  conf_list = [0]
              text = "".join(char_list)
              result_list.append((text, np.mean(conf_list).tolist()))
          return result_list

      def __call__(self, preds, label=None, length=None, *args, **kwargs):
          """Decode model output into ``(text, confidence)`` tuples.

          Eval mode (``len(preds) == 2``): ``preds`` is ``(text_pre, x)``;
          ``text_pre`` is indexed as ``[step, batch, class]`` and ``x`` only
          supplies the dtype of the buffers. Train mode: ``preds[0]`` holds
          per-sample outputs that are cut to the caller-supplied ``length``.

          Confidence is the geometric mean of the per-step top-1 softmax
          probabilities. With ``label``, returns ``(text, decoded_label)``.
          """
          if len(preds) == 2:  # eval mode
              text_pre, x = preds
              b = text_pre.shape[1]
              lenText = self.max_text_length
              nsteps = self.max_text_length
              if not isinstance(text_pre, paddle.Tensor):
                  text_pre = paddle.to_tensor(text_pre, dtype="float32")
              out_res = paddle.zeros(shape=[lenText, b, self.nclass], dtype=x.dtype)
              # 0 means "no end found yet" for that sample.
              out_length = paddle.zeros(shape=[b], dtype=x.dtype)
              now_step = 0
              for _ in range(nsteps):
                  if 0 in out_length and now_step < nsteps:
                      tmp_result = text_pre[now_step, :, :]
                      out_res[now_step] = tmp_result
                      # Top-1 class id of every sample at this step.
                      tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
                      for j in range(b):
                          # The first step whose top-1 id is 0 fixes the
                          # sample's length (counting that step).
                          if out_length[j] == 0 and tmp_result[j] == 0:
                              out_length[j] = now_step + 1
                      now_step += 1
              # Samples that never produced id 0 use all nsteps steps.
              for j in range(0, b):
                  if int(out_length[j]) == 0:
                      out_length[j] = nsteps
              # Pack each sample's first out_length[i] steps back to back into
              # a (sum(out_length), nclass) tensor.
              start = 0
              output = paddle.zeros(
                  shape=[int(out_length.sum()), self.nclass], dtype=x.dtype
              )
              for i in range(0, b):
                  cur_length = int(out_length[i])
                  output[start : start + cur_length] = out_res[0:cur_length, i, :]
                  start += cur_length
              net_out = output
              length = out_length
          else:  # train mode
              net_out = preds[0]
              length = length  # no-op: lengths come from the caller
              # Keep the first l steps of each sample and concatenate them.
              net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
          text = []
          if not isinstance(net_out, paddle.Tensor):
              net_out = paddle.to_tensor(net_out, dtype="float32")
          # Softmax over axis 1, the class axis of the packed eval-mode output;
          # presumably the same holds for the train-mode concat — confirm.
          net_out = F.softmax(net_out, axis=1)
          for i in range(0, length.shape[0]):
              # Row range of sample i inside the packed net_out.
              if i == 0:
                  start_idx = 0
                  end_idx = int(length[i])
              else:
                  start_idx = int(length[:i].sum())
                  end_idx = int(length[:i].sum() + length[i])
              preds_idx = net_out[start_idx:end_idx].topk(1)[1][:, 0].tolist()
              preds_text = "".join(
                  [
                      (
                          self.character[idx - 1]
                          if idx > 0 and idx <= len(self.character)
                          else ""
                      )
                      for idx in preds_idx
                  ]
              )
              preds_prob = net_out[start_idx:end_idx].topk(1)[0][:, 0]
              # Geometric mean of the top-1 probabilities; the 1e-6 avoids a
              # division by zero for an empty slice.
              preds_prob = paddle.exp(
                  paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)
              )
              text.append((preds_text, float(preds_prob)))
          if label is None:
              return text
          label = self.decode(label)
          return text, label
  class CANLabelDecode(BaseRecLabelDecode):
      """Convert between latex-symbol and symbol-index"""

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(CANLabelDecode, self).__init__(character_dict_path, use_space_char)

      def decode(self, text_index, preds_prob=None):
          """Convert symbol-index rows into space-joined symbol strings.

          Each row is cut at its first index 0 (treated as the end token). A
          row that contains no 0 is kept whole.

          Args:
              text_index: batch of index rows (anything with ``.tolist()``).
              preds_prob: optional per-position probabilities; the first
                  ``len(symbols)`` entries of each row are returned.

          Returns:
              list of ``[symbol_string, probs]`` pairs, one per row.
          """
          result_list = []
          batch_size = len(text_index)
          for batch_idx in range(batch_size):
              idx_list = text_index[batch_idx].tolist()
              # ``argmin(0)`` only located the end token when a 0 was present;
              # otherwise it cut the row at its smallest symbol index.
              seq_end = idx_list.index(0) if 0 in idx_list else len(idx_list)
              symbol_list = [self.character[idx] for idx in idx_list[:seq_end]]
              probs = []
              if preds_prob is not None:
                  probs = preds_prob[batch_idx][: len(symbol_list)].tolist()
              result_list.append([" ".join(symbol_list), probs])
          return result_list

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode the argmax of the first model output and, if given, labels."""
          pred_prob, _, _, _ = preds
          preds_idx = pred_prob.argmax(axis=2)
          text = self.decode(preds_idx)
          if label is None:
              return text
          label = self.decode(label)
          return text, label
  class CPPDLabelDecode(NRTRLabelDecode):
      """Convert between text-label and text-index for CPPD.

      The character list starts with ``</s>``.
      """

      def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
          super(CPPDLabelDecode, self).__init__(character_dict_path, use_space_char)

      def __call__(self, preds, label=None, *args, **kwargs):
          """Decode logits and, if given, the labels.

          For a tuple, only its last element is used; if that element is a
          dict, the last entry of its ``"align"`` output is decoded.
          """
          if isinstance(preds, tuple):
              last = preds[-1]
              source = last["align"][-1] if isinstance(last, dict) else last
              logits = source.numpy()
          elif isinstance(preds, paddle.Tensor):
              logits = preds.numpy()
          else:
              logits = preds
          decoded = self.decode(
              logits.argmax(axis=2), logits.max(axis=2), is_remove_duplicate=False
          )
          if label is not None:
              return decoded, self.decode(label)
          return decoded

      def add_special_char(self, dict_character):
          """Prepend the ``</s>`` token."""
          return ["</s>"] + dict_character
  class LaTeXOCRDecode(object):
      """Convert between latex-symbol and symbol-index.

      Token ids are decoded with a ``tokenizers.Tokenizer`` loaded from
      ``rec_char_dict_path``. The resulting strings are cleaned up by
      ``post_process``.
      """

      # Matches "\operatorname {...}", "\mathrm {...}", "\text {...}" and
      # "\mathbf {...}" groups (a space before "{" is required).
      _TEXT_REG = re.compile(r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})")
      _LETTER = "[a-zA-Z]"
      # Raw string: as a plain literal "[\W_^\d]" contains invalid escape
      # sequences (SyntaxWarning on Python 3.12+).
      _NOLETTER = r"[\W_^\d]"
      # Applied in this order, repeatedly, by ``post_process``; compiled once.
      _SQUASH_PATTERNS = (
          # whitespace between two non-letters, unless at an escaped space "\ "
          re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _NOLETTER)),
          # whitespace between a non-letter and a letter (same exception)
          re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _LETTER)),
          # whitespace between a letter and a non-letter
          re.compile(r"(%s)\s+?(%s)" % (_LETTER, _NOLETTER)),
      )

      def __init__(self, rec_char_dict_path, **kwargs):
          # Set the TOKENIZERS_PARALLELISM environment variable to 'false' to suppress
          # the warning: "The current process just got forked, Disabling parallelism to avoid deadlocks..
          # To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)" from tokenizers
          os.environ["TOKENIZERS_PARALLELISM"] = "false"
          from tokenizers import Tokenizer as TokenizerFast

          super(LaTeXOCRDecode, self).__init__()
          self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)

      def post_process(self, s):
          """Normalize spacing in a decoded LaTeX string.

          Spaces are removed inside the groups matched by ``_TEXT_REG``. The
          three ``_SQUASH_PATTERNS`` are then applied until the string stops
          changing.
          """
          names = iter([x[0].replace(" ", "") for x in self._TEXT_REG.findall(s)])
          s = self._TEXT_REG.sub(lambda match: next(names), s)
          while True:
              news = s
              for pattern in self._SQUASH_PATTERNS:
                  news = pattern.sub(r"\1\2", news)
              if news == s:
                  return s
              s = news

      def decode(self, tokens):
          """Decode a 1-D or 2-D numpy array of token ids into strings.

          A 1-D input is treated as a batch of one. Spaces inserted by the
          tokenizer are dropped, its "Ġ" marker is mapped back to a space, and
          ``[EOS]``/``[BOS]``/``[PAD]`` are removed before ``post_process``.
          """
          if len(tokens.shape) == 1:
              tokens = tokens[None, :]
          dec = [self.tokenizer.decode(tok) for tok in tokens]
          dec_str_list = [
              "".join(detok.split(" "))
              .replace("Ġ", " ")
              .replace("[EOS]", "")
              .replace("[BOS]", "")
              .replace("[PAD]", "")
              .strip()
              for detok in dec
          ]
          return [self.post_process(dec_str) for dec_str in dec_str_list]

      def __call__(self, preds, label=None, mode="eval", *args, **kwargs):
          """Decode predictions (argmax over axis 2 in train mode) and labels."""
          if mode == "train":
              preds_idx = np.array(preds.argmax(axis=2))
              text = self.decode(preds_idx)
          else:
              text = self.decode(np.array(preds))
          if label is None:
              return text
          label = self.decode(np.array(label))
          return text, label
  class UniMERNetDecode(object):
      """Decode UniMERNet token ids into normalized LaTeX strings.

      Loads ``tokenizer.json`` from ``rec_char_dict_path`` with the
      ``tokenizers`` library. It then registers the added tokens listed in
      ``tokenizer_config.json`` (when present) plus the special tokens set up
      in ``__init__``.
      """

      SPECIAL_TOKENS_ATTRIBUTES = [
          "bos_token",
          "eos_token",
          "unk_token",
          "sep_token",
          "pad_token",
          "cls_token",
          "mask_token",
          "additional_special_tokens",
      ]

      # Shared by ``normalize`` and ``normalize_infer``.
      _TEXT_REG = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
      _LETTER = "[a-zA-Z]"
      # Raw string: written as a plain literal, "[\W_^\d]" contains invalid
      # escape sequences, which Python 3.12+ reports as a SyntaxWarning.
      _NOLETTER = r"[\W_^\d]"

      def __init__(
          self,
          rec_char_dict_path,
          is_infer=False,
          **kwargs,
      ):
          """
          Args:
              rec_char_dict_path: directory holding ``tokenizer.json`` and,
                  optionally, ``tokenizer_config.json``.
              is_infer: selects the inference-time normalization used by
                  ``post_process``.

          Raises:
              ValueError: an ``added_tokens_decoder`` entry in the config is
                  neither a dict nor an ``AddedToken``.
          """
          # Set the TOKENIZERS_PARALLELISM environment variable to 'false' to suppress
          # the warning: "The current process just got forked, Disabling parallelism to avoid deadlocks..
          # To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)" from tokenizers
          os.environ["TOKENIZERS_PARALLELISM"] = "false"
          from tokenizers import AddedToken
          from tokenizers import Tokenizer as TokenizerFast

          self.is_infer = is_infer
          self._unk_token = "<unk>"
          self._bos_token = "<s>"
          self._eos_token = "</s>"
          self._pad_token = "<pad>"
          self._sep_token = None
          self._cls_token = None
          self._mask_token = None
          self._additional_special_tokens = []
          self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
          self.max_seq_len = 2048
          self.pad_token_id = 1
          self.bos_token_id = 0
          self.eos_token_id = 2
          self.padding_side = "right"
          self.pad_token = "<pad>"
          self.pad_token_type_id = 0
          self.pad_to_multiple_of = None

          fast_tokenizer_file = os.path.join(rec_char_dict_path, "tokenizer.json")
          tokenizer_config_file = os.path.join(
              rec_char_dict_path, "tokenizer_config.json"
          )
          self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)

          # The config is optional. ``os.path.join`` never returns None, so the
          # old ``is not None`` guard always passed and a missing file raised.
          init_kwargs = {}
          if os.path.isfile(tokenizer_config_file):
              with open(tokenizer_config_file, encoding="utf-8") as handle:
                  init_kwargs = json.load(handle)

          added_tokens_decoder = {}
          for idx, token in init_kwargs.get("added_tokens_decoder", {}).items():
              if isinstance(token, dict):
                  token = AddedToken(**token)
              if not isinstance(token, AddedToken):
                  raise ValueError(
                      f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
                  )
              added_tokens_decoder[int(idx)] = token

          # Added tokens in id order, then every special token not yet covered.
          tokens_to_add = [
              token
              for _, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
          ]
          encoder = list(self.added_tokens_encoder(added_tokens_decoder).keys()) + [
              str(token) for token in tokens_to_add
          ]
          tokens_to_add += [
              token
              for token in self.all_special_tokens_extended
              if token not in encoder and token not in tokens_to_add
          ]

          # Register consecutive runs of tokens sharing the same "special"
          # flag with one call per run.
          special_tokens = self.all_special_tokens
          is_last_special = None
          tokens = []
          for token in tokens_to_add:
              is_special = (
                  (token.special or str(token) in special_tokens)
                  if isinstance(token, AddedToken)
                  else str(token) in special_tokens
              )
              if is_last_special is None or is_last_special == is_special:
                  tokens.append(token)
              else:
                  self._add_tokens(tokens, special_tokens=is_last_special)
                  tokens = [token]
              is_last_special = is_special
          if tokens:
              self._add_tokens(tokens, special_tokens=is_last_special)

      def _add_tokens(self, new_tokens, special_tokens=False) -> int:
          """Add tokens to the wrapped tokenizer as special or regular tokens."""
          if special_tokens:
              return self.tokenizer.add_special_tokens(new_tokens)
          return self.tokenizer.add_tokens(new_tokens)

      def added_tokens_encoder(self, added_tokens_decoder):
          """Invert ``{id: AddedToken}`` into ``{content: id}``, in id order."""
          return {
              k.content: v
              for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
          }

      @property
      def all_special_tokens(self):
          """Special tokens as plain strings."""
          return [str(s) for s in self.all_special_tokens_extended]

      @property
      def all_special_tokens_extended(self):
          """De-duplicated special tokens, in ``SPECIAL_TOKENS_ATTRIBUTES`` order."""
          all_tokens = []
          seen = set()
          for value in self.special_tokens_map_extended.values():
              if isinstance(value, (list, tuple)):
                  tokens_to_add = [token for token in value if str(token) not in seen]
              else:
                  tokens_to_add = [value] if str(value) not in seen else []
              seen.update(map(str, tokens_to_add))
              all_tokens.extend(tokens_to_add)
          return all_tokens

      @property
      def special_tokens_map_extended(self):
          """Map each attribute name to its ``_<name>`` value when that value is truthy."""
          set_attr = {}
          for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
              attr_value = getattr(self, "_" + attr)
              if attr_value:
                  set_attr[attr] = attr_value
          return set_attr

      @property
      def all_special_ids(self):
          """Ids of the special tokens that the tokenizer knows (unknown ones are skipped).

          ``convert_ids_to_tokens(..., skip_special_tokens=True)`` relies on
          this; it was previously undefined, which raised AttributeError.
          """
          ids = (self.tokenizer.token_to_id(tok) for tok in self.all_special_tokens)
          return [i for i in ids if i is not None]

      def convert_ids_to_tokens(self, ids, skip_special_tokens: bool = False):
          """Map an int id, or an iterable of ids, to token strings.

          ``id_to_token`` may return None for ids missing from the vocabulary.
          """
          if isinstance(ids, int):
              return self.tokenizer.id_to_token(ids)
          special_ids = set(self.all_special_ids) if skip_special_tokens else set()
          tokens = []
          for index in ids:
              index = int(index)
              if index in special_ids:
                  continue
              tokens.append(self.tokenizer.id_to_token(index))
          return tokens

      def detokenize(self, tokens):
          """Convert each id row to cleaned tokens.

          "Ġ" becomes a space, tokens are stripped, and ``<s>``, ``</s>`` and
          ``<pad>`` are dropped. Missing tokens (None) become "".
          """
          # Compare against local values instead of setting attributes on the
          # wrapped tokenizer object as the old code did.
          skip = (self._bos_token, self._eos_token, self._pad_token)
          toks = []
          for row in tokens:
              cleaned = []
              for tok in self.convert_ids_to_tokens(row):
                  if tok is None:
                      tok = ""
                  tok = tok.replace("Ġ", " ").strip()
                  if tok not in skip:
                      cleaned.append(tok)
              toks.append(cleaned)
          return toks

      def token2str(self, token_ids) -> list:
          """Decode id rows up to and including the first EOS, then post-process."""
          generated_text = []
          for tok_id in token_ids:
              end_idx = np.argwhere(tok_id == self.eos_token_id)
              if len(end_idx) > 0:
                  end_idx = int(end_idx[0][0])
                  tok_id = tok_id[: end_idx + 1]
              generated_text.append(
                  self.tokenizer.decode(tok_id, skip_special_tokens=True)
              )
          return [self.post_process(text) for text in generated_text]

      def _squash_spaces(self, s):
          """Repeatedly drop whitespace next to non-letters until ``s`` is stable.

          This covers whitespace between two non-letters and between a
          non-letter and a letter (both skipped at an escaped space "\\ "), and
          between a letter and a non-letter.
          """
          letter, noletter = self._LETTER, self._NOLETTER
          news = s
          while True:
              s = news
              news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
              news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
              news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
              if news == s:
                  return s

      def normalize_infer(self, s: str) -> str:
          """Normalizes a string by removing unnecessary spaces.

          Inside "\\operatorname {...}"-style groups, a space that follows
          another command is kept by temporarily marking it with "XXXXXXX".
          All other spaces are removed.

          Args:
              s (str): String to normalize.
          Returns:
              str: Normalized string.
          """
          for x in re.findall(self._TEXT_REG, s):
              matches = re.findall(r"(\\[a-zA-Z]+)\s(?=\w)|\\[a-zA-Z]+\s(?=})", x[0])
              for m in matches:
                  if (
                      m not in ["\\operatorname", "\\mathrm", "\\text", "\\mathbf"]
                      and m.strip() != ""
                  ):
                      s = s.replace(m, m + "XXXXXXX")
                      s = s.replace(" ", "")
          # The old code then ran re.sub(text_reg, ...) with whole-string
          # snapshots. That could never match: text_reg needs " {", and each
          # snapshot was taken after all spaces were removed. It is dropped.
          return self._squash_spaces(s).replace("XXXXXXX", " ")

      def remove_chinese_text_wrapping(self, formula):
          """Unwrap ``\\text{...}`` groups containing CJK characters and drop '"'."""
          pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")

          def replacer(match):
              return match.group(1)

          replaced_formula = pattern.sub(replacer, formula)
          return replaced_formula.replace('"', "")

      def normalize(self, s):
          """Remove spaces inside text-style groups, then squash other spaces."""
          names = iter([x[0].replace(" ", "") for x in re.findall(self._TEXT_REG, s)])
          s = re.sub(self._TEXT_REG, lambda match: next(names), s)
          return self._squash_spaces(s)

      def post_process(self, text: str) -> str:
          """Post-processes a string by fixing text and normalizing it.

          Args:
              text (str): String to post-process.
          Returns:
              str: Post-processed string.
          """
          from ftfy import fix_text

          if self.is_infer:
              text = self.remove_chinese_text_wrapping(text)
              text = fix_text(text)
              text = self.normalize_infer(text)
          else:
              text = fix_text(text)
              text = self.normalize(text)
          return text

      def __call__(self, preds, label=None, mode="eval", *args, **kwargs):
          """Decode predictions (argmax over axis 2 in train mode) and labels."""
          if mode == "train":
              preds_idx = np.array(preds.argmax(axis=2))
              text = self.token2str(preds_idx)
          else:
              text = self.token2str(np.array(preds))
          if label is None:
              return text
          label = self.token2str(np.array(label))
          return text, label